瀏覽代碼

Add Saml2AuthenticationRequestRepository

Closes gh-9185
Marcus Da Coregio 4 年之前
父節點
當前提交
16e17d242e
共有 15 個文件被更改,包括 655 次插入17 次删除
  1. 26 3
      config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java
  2. 48 1
      config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java
  3. 30 0
      docs/manual/src/docs/asciidoc/_includes/servlet/saml2/saml2-login.adoc
  4. 1 0
      saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle
  5. 37 4
      saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationToken.java
  6. 73 0
      saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/HttpSessionSaml2AuthenticationRequestRepository.java
  7. 60 0
      saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/Saml2AuthenticationRequestRepository.java
  8. 32 1
      saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilter.java
  9. 26 5
      saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java
  10. 26 1
      saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java
  11. 38 0
      saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestSaml2AuthenticationTokens.java
  12. 143 0
      saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/HttpSessionSaml2AuthenticationRequestRepositoryTests.java
  13. 50 1
      saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilterTests.java
  14. 41 1
      saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java
  15. 24 0
      saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java

+ 26 - 3
config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java

@@ -1,5 +1,5 @@
 /*
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * you may not use this file except in compliance with the License.
@@ -33,6 +33,7 @@ import org.springframework.security.config.annotation.web.configurers.AbstractAu
 import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer;
 import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer;
 import org.springframework.security.config.annotation.web.configurers.CsrfConfigurer;
 import org.springframework.security.config.annotation.web.configurers.CsrfConfigurer;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.Authentication;
+import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
 import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider;
 import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider;
 import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationRequestFactory;
 import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationRequestFactory;
 import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider;
 import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider;
@@ -40,6 +41,8 @@ import org.springframework.security.saml2.provider.service.authentication.OpenSa
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
+import org.springframework.security.saml2.provider.service.servlet.HttpSessionSaml2AuthenticationRequestRepository;
+import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository;
 import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
 import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
 import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationRequestFilter;
 import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationRequestFilter;
 import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
 import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
@@ -206,6 +209,7 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>>
 		}
 		}
 		this.saml2WebSsoAuthenticationFilter = new Saml2WebSsoAuthenticationFilter(getAuthenticationConverter(http),
 		this.saml2WebSsoAuthenticationFilter = new Saml2WebSsoAuthenticationFilter(getAuthenticationConverter(http),
 				this.loginProcessingUrl);
 				this.loginProcessingUrl);
+		setAuthenticationRequestRepository(http, this.saml2WebSsoAuthenticationFilter);
 		setAuthenticationFilter(this.saml2WebSsoAuthenticationFilter);
 		setAuthenticationFilter(this.saml2WebSsoAuthenticationFilter);
 		super.loginProcessingUrl(this.loginProcessingUrl);
 		super.loginProcessingUrl(this.loginProcessingUrl);
 		if (StringUtils.hasText(this.loginPage)) {
 		if (StringUtils.hasText(this.loginPage)) {
@@ -252,6 +256,11 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>>
 		}
 		}
 	}
 	}
 
 
+	private void setAuthenticationRequestRepository(B http,
+			Saml2WebSsoAuthenticationFilter saml2WebSsoAuthenticationFilter) {
+		saml2WebSsoAuthenticationFilter.setAuthenticationRequestRepository(getAuthenticationRequestRepository(http));
+	}
+
 	private AuthenticationConverter getAuthenticationConverter(B http) {
 	private AuthenticationConverter getAuthenticationConverter(B http) {
 		if (this.authenticationConverter == null) {
 		if (this.authenticationConverter == null) {
 			return new Saml2AuthenticationTokenConverter(
 			return new Saml2AuthenticationTokenConverter(
@@ -311,6 +320,16 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>>
 		return idps;
 		return idps;
 	}
 	}
 
 
+	private Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> getAuthenticationRequestRepository(
+			B http) {
+		Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> repository = getBeanOrNull(http,
+				Saml2AuthenticationRequestRepository.class);
+		if (repository == null) {
+			return new HttpSessionSaml2AuthenticationRequestRepository();
+		}
+		return repository;
+	}
+
 	private <C> C getSharedOrBean(B http, Class<C> clazz) {
 	private <C> C getSharedOrBean(B http, Class<C> clazz) {
 		C shared = http.getSharedObject(clazz);
 		C shared = http.getSharedObject(clazz);
 		if (shared != null) {
 		if (shared != null) {
@@ -348,8 +367,12 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>>
 		private Filter build(B http) {
 		private Filter build(B http) {
 			Saml2AuthenticationRequestFactory authenticationRequestResolver = getResolver(http);
 			Saml2AuthenticationRequestFactory authenticationRequestResolver = getResolver(http);
 			Saml2AuthenticationRequestContextResolver contextResolver = getContextResolver(http);
 			Saml2AuthenticationRequestContextResolver contextResolver = getContextResolver(http);
-			return postProcess(
-					new Saml2WebSsoAuthenticationRequestFilter(contextResolver, authenticationRequestResolver));
+			Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> repository = getAuthenticationRequestRepository(
+					http);
+			Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter(contextResolver,
+					authenticationRequestResolver);
+			filter.setAuthenticationRequestRepository(repository);
+			return postProcess(filter);
 		}
 		}
 
 
 		private Saml2AuthenticationRequestFactory getResolver(B http) {
 		private Saml2AuthenticationRequestFactory getResolver(B http) {

+ 48 - 1
config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java

@@ -63,6 +63,7 @@ import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMap
 import org.springframework.security.saml2.core.Saml2ErrorCodes;
 import org.springframework.security.saml2.core.Saml2ErrorCodes;
 import org.springframework.security.saml2.core.Saml2Utils;
 import org.springframework.security.saml2.core.Saml2Utils;
 import org.springframework.security.saml2.core.TestSaml2X509Credentials;
 import org.springframework.security.saml2.core.TestSaml2X509Credentials;
+import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
 import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider;
 import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider;
 import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationRequestFactory;
 import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationRequestFactory;
 import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider;
 import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider;
@@ -76,9 +77,11 @@ import org.springframework.security.saml2.provider.service.authentication.TestSa
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
 import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
 import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
+import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository;
 import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
 import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
 import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
 import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
 import org.springframework.security.web.FilterChainProxy;
 import org.springframework.security.web.FilterChainProxy;
+import org.springframework.security.web.SecurityFilterChain;
 import org.springframework.security.web.authentication.AuthenticationConverter;
 import org.springframework.security.web.authentication.AuthenticationConverter;
 import org.springframework.security.web.authentication.AuthenticationFailureHandler;
 import org.springframework.security.web.authentication.AuthenticationFailureHandler;
 import org.springframework.security.web.context.HttpRequestResponseHolder;
 import org.springframework.security.web.context.HttpRequestResponseHolder;
@@ -237,6 +240,29 @@ public class Saml2LoginConfigurerTests {
 		assertThat(exception.getCause()).isInstanceOf(IOException.class);
 		assertThat(exception.getCause()).isInstanceOf(IOException.class);
 	}
 	}
 
 
+	@Test
+	public void authenticationRequestWhenCustomAuthnRequestRepositoryThenUses() throws Exception {
+		this.spring.register(CustomAuthenticationRequestRepository.class).autowire();
+		MockHttpServletRequestBuilder request = get("/saml2/authenticate/registration-id");
+		this.mvc.perform(request).andExpect(status().isFound());
+		Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> repository = this.spring.getContext()
+				.getBean(Saml2AuthenticationRequestRepository.class);
+		verify(repository).saveAuthenticationRequest(any(AbstractSaml2AuthenticationRequest.class),
+				any(HttpServletRequest.class), any(HttpServletResponse.class));
+	}
+
+	@Test
+	public void authenticateWhenCustomAuthnRequestRepositoryThenUses() throws Exception {
+		this.spring.register(CustomAuthenticationRequestRepository.class).autowire();
+		MockHttpServletRequestBuilder request = post("/login/saml2/sso/registration-id").param("SAMLResponse",
+				SIGNED_RESPONSE);
+		Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> repository = this.spring.getContext()
+				.getBean(Saml2AuthenticationRequestRepository.class);
+		this.mvc.perform(request);
+		verify(repository).loadAuthenticationRequest(any(HttpServletRequest.class));
+		verify(repository).removeAuthenticationRequest(any(HttpServletRequest.class), any(HttpServletResponse.class));
+	}
+
 	private void validateSaml2WebSsoAuthenticationFilterConfiguration() {
 	private void validateSaml2WebSsoAuthenticationFilterConfiguration() {
 		// get the OpenSamlAuthenticationProvider
 		// get the OpenSamlAuthenticationProvider
 		Saml2WebSsoAuthenticationFilter filter = getSaml2SsoFilter(this.springSecurityFilterChain);
 		Saml2WebSsoAuthenticationFilter filter = getSaml2SsoFilter(this.springSecurityFilterChain);
@@ -371,7 +397,7 @@ public class Saml2LoginConfigurerTests {
 
 
 		@Bean
 		@Bean
 		Saml2AuthenticationRequestContextResolver resolver() {
 		Saml2AuthenticationRequestContextResolver resolver() {
-			return resolver;
+			return this.resolver;
 		}
 		}
 
 
 	}
 	}
@@ -420,6 +446,27 @@ public class Saml2LoginConfigurerTests {
 
 
 	}
 	}
 
 
+	@EnableWebSecurity
+	@Import(Saml2LoginConfigBeans.class)
+	static class CustomAuthenticationRequestRepository {
+
+		private static final Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> repository = mock(
+				Saml2AuthenticationRequestRepository.class);
+
+		@Bean
+		SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
+			http.authorizeRequests((authz) -> authz.anyRequest().authenticated());
+			http.saml2Login(withDefaults());
+			return http.build();
+		}
+
+		@Bean
+		Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository() {
+			return this.repository;
+		}
+
+	}
+
 	static class Saml2LoginConfigBeans {
 	static class Saml2LoginConfigBeans {
 
 
 		@Bean
 		@Bean

+ 30 - 0
docs/manual/src/docs/asciidoc/_includes/servlet/saml2/saml2-login.adoc

@@ -1610,3 +1610,33 @@ http {
 The success handler will send logout requests to the asserting party.
 The success handler will send logout requests to the asserting party.
 
 
 The request matcher will detect logout requests from the asserting party.
 The request matcher will detect logout requests from the asserting party.
+
+[[servlet-saml2login-store-authn-request]]
+=== Storing the `AuthnRequest`
+
+The `Saml2AuthenticationRequestRepository` is responsible for the persistence of the `AuthnRequest` from the time the `AuthnRequest` <<servlet-saml2login-sp-initiated-factory,is initiated>> to the time the `SAMLResponse` <<servlet-saml2login-authenticate-responses,is received>>.
+The `Saml2AuthenticationTokenConverter` is responsible for loading the `AuthnRequest` from the `Saml2AuthenticationRequestRepository` and saving it into the `Saml2AuthenticationToken`.
+
+The default implementation of `Saml2AuthenticationRequestRepository` is `HttpSessionSaml2AuthenticationRequestRepository`, which stores the `AuthnRequest` in the `HttpSession`.
+
+If you have a custom implementation of `Saml2AuthenticationRequestRepository`, you may configure it by exposing it as a `@Bean` as shown in the following example:
+
+====
+.Java
+[source,java,role="primary"]
+----
+@Bean
+Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository() {
+	return new CustomSaml2AuthenticationRequestRepository();
+}
+----
+
+.Kotlin
+[source,kotlin,role="secondary"]
+----
+@Bean
+open fun authenticationRequestRepository(): Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> {
+    return CustomSaml2AuthenticationRequestRepository()
+}
+----
+====

+ 1 - 0
saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle

@@ -55,6 +55,7 @@ dependencies {
 	testImplementation "org.junit.jupiter:junit-jupiter-params"
 	testImplementation "org.junit.jupiter:junit-jupiter-params"
 	testImplementation "org.junit.jupiter:junit-jupiter-engine"
 	testImplementation "org.junit.jupiter:junit-jupiter-engine"
 	testImplementation "org.mockito:mockito-core"
 	testImplementation "org.mockito:mockito-core"
+	testImplementation "org.mockito:mockito-inline"
 	testImplementation "org.mockito:mockito-junit-jupiter"
 	testImplementation "org.mockito:mockito-junit-jupiter"
 	testImplementation "org.springframework:spring-test"
 	testImplementation "org.springframework:spring-test"
 }
 }

+ 37 - 4
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationToken.java

@@ -1,5 +1,5 @@
 /*
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * you may not use this file except in compliance with the License.
@@ -38,8 +38,10 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken {
 
 
 	private final String saml2Response;
 	private final String saml2Response;
 
 
+	private final AbstractSaml2AuthenticationRequest authenticationRequest;
+
 	/**
 	/**
-	 * Creates a {@link Saml2AuthenticationToken} with the provided parameters
+	 * Creates a {@link Saml2AuthenticationToken} with the provided parameters.
 	 *
 	 *
 	 * Note that the given {@link RelyingPartyRegistration} should have all its templates
 	 * Note that the given {@link RelyingPartyRegistration} should have all its templates
 	 * resolved at this point. See
 	 * resolved at this point. See
@@ -48,15 +50,35 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken {
 	 * @param relyingPartyRegistration the resolved {@link RelyingPartyRegistration} to
 	 * @param relyingPartyRegistration the resolved {@link RelyingPartyRegistration} to
 	 * use
 	 * use
 	 * @param saml2Response the SAML 2.0 response to authenticate
 	 * @param saml2Response the SAML 2.0 response to authenticate
+	 * @param authenticationRequest the {@code AuthNRequest} sent to the asserting party
 	 *
 	 *
-	 * @since 5.4
+	 * @since 5.6
 	 */
 	 */
-	public Saml2AuthenticationToken(RelyingPartyRegistration relyingPartyRegistration, String saml2Response) {
+	public Saml2AuthenticationToken(RelyingPartyRegistration relyingPartyRegistration, String saml2Response,
+			AbstractSaml2AuthenticationRequest authenticationRequest) {
 		super(Collections.emptyList());
 		super(Collections.emptyList());
 		Assert.notNull(relyingPartyRegistration, "relyingPartyRegistration cannot be null");
 		Assert.notNull(relyingPartyRegistration, "relyingPartyRegistration cannot be null");
 		Assert.notNull(saml2Response, "saml2Response cannot be null");
 		Assert.notNull(saml2Response, "saml2Response cannot be null");
 		this.relyingPartyRegistration = relyingPartyRegistration;
 		this.relyingPartyRegistration = relyingPartyRegistration;
 		this.saml2Response = saml2Response;
 		this.saml2Response = saml2Response;
+		this.authenticationRequest = authenticationRequest;
+	}
+
+	/**
+	 * Creates a {@link Saml2AuthenticationToken} with the provided parameters
+	 *
+	 * Note that the given {@link RelyingPartyRegistration} should have all its templates
+	 * resolved at this point. See
+	 * {@link org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter}
+	 * for an example of performing that resolution.
+	 * @param relyingPartyRegistration the resolved {@link RelyingPartyRegistration} to
+	 * use
+	 * @param saml2Response the SAML 2.0 response to authenticate
+	 *
+	 * @since 5.4
+	 */
+	public Saml2AuthenticationToken(RelyingPartyRegistration relyingPartyRegistration, String saml2Response) {
+		this(relyingPartyRegistration, saml2Response, null);
 	}
 	}
 
 
 	/**
 	/**
@@ -81,6 +103,7 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken {
 						.entityId(idpEntityId).singleSignOnServiceLocation(idpEntityId))
 						.entityId(idpEntityId).singleSignOnServiceLocation(idpEntityId))
 				.build();
 				.build();
 		this.saml2Response = saml2Response;
 		this.saml2Response = saml2Response;
+		this.authenticationRequest = null;
 	}
 	}
 
 
 	/**
 	/**
@@ -179,4 +202,14 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken {
 		return this.relyingPartyRegistration.getAssertingPartyDetails().getEntityId();
 		return this.relyingPartyRegistration.getAssertingPartyDetails().getEntityId();
 	}
 	}
 
 
+	/**
+	 * Returns the authentication request sent to the assertion party or {@code null} if
+	 * no authentication request is present
+	 * @return the authentication request sent to the assertion party
+	 * @since 5.6
+	 */
+	public AbstractSaml2AuthenticationRequest getAuthenticationRequest() {
+		return this.authenticationRequest;
+	}
+
 }
 }

+ 73 - 0
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/HttpSessionSaml2AuthenticationRequestRepository.java

@@ -0,0 +1,73 @@
+/*
+ * Copyright 2002-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.saml2.provider.service.servlet;
+
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+import javax.servlet.http.HttpSession;
+
+import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
+
+/**
+ * A {@link Saml2AuthenticationRequestRepository} implementation that uses
+ * {@link HttpSession} to store and retrieve the
+ * {@link AbstractSaml2AuthenticationRequest}
+ *
+ * @author Marcus Da Coregio
+ * @since 5.6
+ */
+public class HttpSessionSaml2AuthenticationRequestRepository
+		implements Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> {
+
+	private static final String DEFAULT_SAML2_AUTHN_REQUEST_ATTR_NAME = HttpSessionSaml2AuthenticationRequestRepository.class
+			.getName().concat(".SAML2_AUTHN_REQUEST");
+
+	private String saml2AuthnRequestAttributeName = DEFAULT_SAML2_AUTHN_REQUEST_ATTR_NAME;
+
+	@Override
+	public AbstractSaml2AuthenticationRequest loadAuthenticationRequest(HttpServletRequest request) {
+		HttpSession httpSession = request.getSession(false);
+		if (httpSession == null) {
+			return null;
+		}
+		return (AbstractSaml2AuthenticationRequest) httpSession.getAttribute(this.saml2AuthnRequestAttributeName);
+	}
+
+	@Override
+	public void saveAuthenticationRequest(AbstractSaml2AuthenticationRequest authenticationRequest,
+			HttpServletRequest request, HttpServletResponse response) {
+		if (authenticationRequest == null) {
+			removeAuthenticationRequest(request, response);
+			return;
+		}
+		HttpSession httpSession = request.getSession();
+		httpSession.setAttribute(this.saml2AuthnRequestAttributeName, authenticationRequest);
+	}
+
+	@Override
+	public AbstractSaml2AuthenticationRequest removeAuthenticationRequest(HttpServletRequest request,
+			HttpServletResponse response) {
+		AbstractSaml2AuthenticationRequest authenticationRequest = loadAuthenticationRequest(request);
+		if (authenticationRequest == null) {
+			return null;
+		}
+		HttpSession httpSession = request.getSession();
+		httpSession.removeAttribute(this.saml2AuthnRequestAttributeName);
+		return authenticationRequest;
+	}
+
+}

+ 60 - 0
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/Saml2AuthenticationRequestRepository.java

@@ -0,0 +1,60 @@
+/*
+ * Copyright 2002-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.saml2.provider.service.servlet;
+
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+
+import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
+
+/**
+ * A repository for {@link AbstractSaml2AuthenticationRequest}
+ *
+ * @param <T> the type of SAML 2.0 Authentication Request
+ * @author Marcus Da Coregio
+ * @since 5.6
+ */
+public interface Saml2AuthenticationRequestRepository<T extends AbstractSaml2AuthenticationRequest> {
+
+	/**
+	 * Loads the {@link AbstractSaml2AuthenticationRequest} from the request
+	 * @param request the current request
+	 * @return the {@link AbstractSaml2AuthenticationRequest} or {@code null} if it is not
+	 * present
+	 */
+	T loadAuthenticationRequest(HttpServletRequest request);
+
+	/**
+	 * Saves the current authentication request using the {@link HttpServletRequest} and
+	 * {@link HttpServletResponse}
+	 * @param authenticationRequest the {@link AbstractSaml2AuthenticationRequest}
+	 * @param request the current request
+	 * @param response the current response
+	 */
+	void saveAuthenticationRequest(T authenticationRequest, HttpServletRequest request, HttpServletResponse response);
+
+	/**
+	 * Removes the authentication request using the {@link HttpServletRequest} and
+	 * {@link HttpServletResponse}
+	 * @param request the current request
+	 * @param response the current response
+	 * @return the removed {@link AbstractSaml2AuthenticationRequest} or {@code null} if
+	 * it is not present
+	 */
+	T removeAuthenticationRequest(HttpServletRequest request, HttpServletResponse response);
+
+}

+ 32 - 1
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilter.java

@@ -1,5 +1,5 @@
 /*
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * you may not use this file except in compliance with the License.
@@ -23,8 +23,11 @@ import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.saml2.core.Saml2Error;
 import org.springframework.security.saml2.core.Saml2Error;
 import org.springframework.security.saml2.core.Saml2ErrorCodes;
 import org.springframework.security.saml2.core.Saml2ErrorCodes;
+import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
+import org.springframework.security.saml2.provider.service.servlet.HttpSessionSaml2AuthenticationRequestRepository;
+import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository;
 import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
 import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
 import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter;
 import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter;
 import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
 import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
@@ -42,6 +45,8 @@ public class Saml2WebSsoAuthenticationFilter extends AbstractAuthenticationProce
 
 
 	private final AuthenticationConverter authenticationConverter;
 	private final AuthenticationConverter authenticationConverter;
 
 
+	private Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository = new HttpSessionSaml2AuthenticationRequestRepository();
+
 	/**
 	/**
 	 * Creates a {@code Saml2WebSsoAuthenticationFilter} authentication filter that is
 	 * Creates a {@code Saml2WebSsoAuthenticationFilter} authentication filter that is
 	 * configured to use the {@link #DEFAULT_FILTER_PROCESSES_URI} processing URL
 	 * configured to use the {@link #DEFAULT_FILTER_PROCESSES_URI} processing URL
@@ -100,7 +105,33 @@ public class Saml2WebSsoAuthenticationFilter extends AbstractAuthenticationProce
 					"No relying party registration found");
 					"No relying party registration found");
 			throw new Saml2AuthenticationException(saml2Error);
 			throw new Saml2AuthenticationException(saml2Error);
 		}
 		}
+		this.authenticationRequestRepository.removeAuthenticationRequest(request, response);
 		return getAuthenticationManager().authenticate(authentication);
 		return getAuthenticationManager().authenticate(authentication);
 	}
 	}
 
 
+	/**
+	 * Use the given {@link Saml2AuthenticationRequestRepository} to remove the saved
+	 * authentication request. If the {@link #authenticationConverter} is of the type
+	 * {@link Saml2AuthenticationTokenConverter}, the
+	 * {@link Saml2AuthenticationRequestRepository} will also be set into the
+	 * {@link #authenticationConverter}.
+	 * @param authenticationRequestRepository the
+	 * {@link Saml2AuthenticationRequestRepository} to use
+	 * @since 5.6
+	 */
+	public void setAuthenticationRequestRepository(
+			Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository) {
+		Assert.notNull(authenticationRequestRepository, "authenticationRequestRepository cannot be null");
+		this.authenticationRequestRepository = authenticationRequestRepository;
+		setAuthenticationRequestRepositoryIntoAuthenticationConverter(authenticationRequestRepository);
+	}
+
+	private void setAuthenticationRequestRepositoryIntoAuthenticationConverter(
+			Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository) {
+		if (this.authenticationConverter instanceof Saml2AuthenticationTokenConverter) {
+			Saml2AuthenticationTokenConverter authenticationTokenConverter = (Saml2AuthenticationTokenConverter) this.authenticationConverter;
+			authenticationTokenConverter.setAuthenticationRequestRepository(authenticationRequestRepository);
+		}
+	}
+
 }
 }

+ 26 - 5
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java

@@ -27,6 +27,7 @@ import javax.servlet.http.HttpServletResponse;
 import org.opensaml.core.Version;
 import org.opensaml.core.Version;
 
 
 import org.springframework.http.MediaType;
 import org.springframework.http.MediaType;
+import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory;
 import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest;
 import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest;
@@ -34,6 +35,8 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2R
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
 import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
 import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
+import org.springframework.security.saml2.provider.service.servlet.HttpSessionSaml2AuthenticationRequestRepository;
+import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository;
 import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
 import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
 import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver;
 import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver;
 import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
 import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
@@ -79,6 +82,8 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
 
 
 	private RequestMatcher redirectMatcher = new AntPathRequestMatcher("/saml2/authenticate/{registrationId}");
 	private RequestMatcher redirectMatcher = new AntPathRequestMatcher("/saml2/authenticate/{registrationId}");
 
 
+	private Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository = new HttpSessionSaml2AuthenticationRequestRepository();
+
 	/**
 	/**
 	 * Construct a {@link Saml2WebSsoAuthenticationRequestFilter} with the provided
 	 * Construct a {@link Saml2WebSsoAuthenticationRequestFilter} with the provided
 	 * parameters
 	 * parameters
@@ -149,6 +154,19 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
 		this.redirectMatcher = redirectMatcher;
 		this.redirectMatcher = redirectMatcher;
 	}
 	}
 
 
+	/**
+	 * Use the given {@link Saml2AuthenticationRequestRepository} to save the
+	 * authentication request
+	 * @param authenticationRequestRepository the
+	 * {@link Saml2AuthenticationRequestRepository} to use
+	 * @since 5.6
+	 */
+	public void setAuthenticationRequestRepository(
+			Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository) {
+		Assert.notNull(authenticationRequestRepository, "authenticationRequestRepository cannot be null");
+		this.authenticationRequestRepository = authenticationRequestRepository;
+	}
+
 	@Override
 	@Override
 	protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
 	protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
 			throws ServletException, IOException {
 			throws ServletException, IOException {
@@ -165,17 +183,18 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
 		}
 		}
 		RelyingPartyRegistration relyingParty = context.getRelyingPartyRegistration();
 		RelyingPartyRegistration relyingParty = context.getRelyingPartyRegistration();
 		if (relyingParty.getAssertingPartyDetails().getSingleSignOnServiceBinding() == Saml2MessageBinding.REDIRECT) {
 		if (relyingParty.getAssertingPartyDetails().getSingleSignOnServiceBinding() == Saml2MessageBinding.REDIRECT) {
-			sendRedirect(response, context);
+			sendRedirect(request, response, context);
 		}
 		}
 		else {
 		else {
-			sendPost(response, context);
+			sendPost(request, response, context);
 		}
 		}
 	}
 	}
 
 
-	private void sendRedirect(HttpServletResponse response, Saml2AuthenticationRequestContext context)
-			throws IOException {
+	private void sendRedirect(HttpServletRequest request, HttpServletResponse response,
+			Saml2AuthenticationRequestContext context) throws IOException {
 		Saml2RedirectAuthenticationRequest authenticationRequest = this.authenticationRequestFactory
 		Saml2RedirectAuthenticationRequest authenticationRequest = this.authenticationRequestFactory
 				.createRedirectAuthenticationRequest(context);
 				.createRedirectAuthenticationRequest(context);
+		this.authenticationRequestRepository.saveAuthenticationRequest(authenticationRequest, request, response);
 		UriComponentsBuilder uriBuilder = UriComponentsBuilder
 		UriComponentsBuilder uriBuilder = UriComponentsBuilder
 				.fromUriString(authenticationRequest.getAuthenticationRequestUri());
 				.fromUriString(authenticationRequest.getAuthenticationRequestUri());
 		addParameter("SAMLRequest", authenticationRequest.getSamlRequest(), uriBuilder);
 		addParameter("SAMLRequest", authenticationRequest.getSamlRequest(), uriBuilder);
@@ -194,9 +213,11 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
 		}
 		}
 	}
 	}
 
 
-	private void sendPost(HttpServletResponse response, Saml2AuthenticationRequestContext context) throws IOException {
+	private void sendPost(HttpServletRequest request, HttpServletResponse response,
+			Saml2AuthenticationRequestContext context) throws IOException {
 		Saml2PostAuthenticationRequest authenticationRequest = this.authenticationRequestFactory
 		Saml2PostAuthenticationRequest authenticationRequest = this.authenticationRequestFactory
 				.createPostAuthenticationRequest(context);
 				.createPostAuthenticationRequest(context);
+		this.authenticationRequestRepository.saveAuthenticationRequest(authenticationRequest, request, response);
 		String html = createSamlPostRequestFormData(authenticationRequest);
 		String html = createSamlPostRequestFormData(authenticationRequest);
 		response.setContentType(MediaType.TEXT_HTML_VALUE);
 		response.setContentType(MediaType.TEXT_HTML_VALUE);
 		response.getWriter().write(html);
 		response.getWriter().write(html);

+ 26 - 1
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java

@@ -18,6 +18,7 @@ package org.springframework.security.saml2.provider.service.web;
 
 
 import java.io.ByteArrayOutputStream;
 import java.io.ByteArrayOutputStream;
 import java.nio.charset.StandardCharsets;
 import java.nio.charset.StandardCharsets;
+import java.util.function.Function;
 import java.util.zip.Inflater;
 import java.util.zip.Inflater;
 import java.util.zip.InflaterOutputStream;
 import java.util.zip.InflaterOutputStream;
 
 
@@ -30,9 +31,12 @@ import org.springframework.core.convert.converter.Converter;
 import org.springframework.http.HttpMethod;
 import org.springframework.http.HttpMethod;
 import org.springframework.security.saml2.core.Saml2Error;
 import org.springframework.security.saml2.core.Saml2Error;
 import org.springframework.security.saml2.core.Saml2ErrorCodes;
 import org.springframework.security.saml2.core.Saml2ErrorCodes;
+import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
+import org.springframework.security.saml2.provider.service.servlet.HttpSessionSaml2AuthenticationRequestRepository;
+import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository;
 import org.springframework.security.web.authentication.AuthenticationConverter;
 import org.springframework.security.web.authentication.AuthenticationConverter;
 import org.springframework.util.Assert;
 import org.springframework.util.Assert;
 
 
@@ -50,6 +54,8 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo
 
 
 	private final Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver;
 	private final Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver;
 
 
+	private Function<HttpServletRequest, AbstractSaml2AuthenticationRequest> loader;
+
 	/**
 	/**
 	 * Constructs a {@link Saml2AuthenticationTokenConverter} given a strategy for
 	 * Constructs a {@link Saml2AuthenticationTokenConverter} given a strategy for
 	 * resolving {@link RelyingPartyRegistration}s
 	 * resolving {@link RelyingPartyRegistration}s
@@ -60,6 +66,7 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo
 			Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver) {
 			Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver) {
 		Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null");
 		Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null");
 		this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver;
 		this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver;
+		this.loader = new HttpSessionSaml2AuthenticationRequestRepository()::loadAuthenticationRequest;
 	}
 	}
 
 
 	@Override
 	@Override
@@ -74,7 +81,25 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo
 		}
 		}
 		byte[] b = samlDecode(saml2Response);
 		byte[] b = samlDecode(saml2Response);
 		saml2Response = inflateIfRequired(request, b);
 		saml2Response = inflateIfRequired(request, b);
-		return new Saml2AuthenticationToken(relyingPartyRegistration, saml2Response);
+		AbstractSaml2AuthenticationRequest authenticationRequest = loadAuthenticationRequest(request);
+		return new Saml2AuthenticationToken(relyingPartyRegistration, saml2Response, authenticationRequest);
+	}
+
+	/**
+	 * Use the given {@link Saml2AuthenticationRequestRepository} to load authentication
+	 * request.
+	 * @param authenticationRequestRepository the
+	 * {@link Saml2AuthenticationRequestRepository} to use
+	 * @since 5.6
+	 */
+	public void setAuthenticationRequestRepository(
+			Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository) {
+		Assert.notNull(authenticationRequestRepository, "authenticationRequestRepository cannot be null");
+		this.loader = authenticationRequestRepository::loadAuthenticationRequest;
+	}
+
+	private AbstractSaml2AuthenticationRequest loadAuthenticationRequest(HttpServletRequest request) {
+		return this.loader.apply(request);
 	}
 	}
 
 
 	private String inflateIfRequired(HttpServletRequest request, byte[] b) {
 	private String inflateIfRequired(HttpServletRequest request, byte[] b) {

+ 38 - 0
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestSaml2AuthenticationTokens.java

@@ -0,0 +1,38 @@
+/*
+ * Copyright 2002-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.saml2.provider.service.authentication;
+
+import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
+import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
+
+/**
+ * Tests instances for {@link Saml2AuthenticationToken}
+ *
+ * @author Marcus Da Coregio
+ */
+public final class TestSaml2AuthenticationTokens {
+
+	private TestSaml2AuthenticationTokens() {
+	}
+
+	public static Saml2AuthenticationToken token() {
+		RelyingPartyRegistration relyingPartyRegistration = TestRelyingPartyRegistrations.relyingPartyRegistration()
+				.build();
+		return new Saml2AuthenticationToken(relyingPartyRegistration, "saml2-xml-response-object");
+	}
+
+}

+ 143 - 0
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/HttpSessionSaml2AuthenticationRequestRepositoryTests.java

@@ -0,0 +1,143 @@
+/*
+ * Copyright 2002-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.saml2.provider.service.servlet;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import org.springframework.mock.web.MockHttpServletRequest;
+import org.springframework.mock.web.MockHttpServletResponse;
+import org.springframework.mock.web.MockHttpSession;
+import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.BDDMockito.given;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+
+/**
+ * @author Marcus Da Coregio
+ */
+public class HttpSessionSaml2AuthenticationRequestRepositoryTests {
+
+	private static final String IDP_SSO_URL = "https://sso-url.example.com/IDP/SSO";
+
+	private MockHttpServletRequest request;
+
+	private MockHttpServletResponse response;
+
+	private HttpSessionSaml2AuthenticationRequestRepository authenticationRequestRepository;
+
+	@BeforeEach
+	public void setup() {
+		this.request = new MockHttpServletRequest();
+		this.response = new MockHttpServletResponse();
+		this.authenticationRequestRepository = new HttpSessionSaml2AuthenticationRequestRepository();
+	}
+
+	@Test
+	public void loadAuthenticationRequestWhenInvalidSessionThenNull() {
+		AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestRepository
+				.loadAuthenticationRequest(this.request);
+		assertThat(authenticationRequest).isNull();
+	}
+
+	@Test
+	public void loadAuthenticationRequestWhenNoAttributeInSessionThenNull() {
+		this.request.getSession();
+		AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestRepository
+				.loadAuthenticationRequest(this.request);
+		assertThat(authenticationRequest).isNull();
+	}
+
+	@Test
+	public void loadAuthenticationRequestWhenAttributeInSessionThenReturnsAuthenticationRequest() {
+		AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mock(AbstractSaml2AuthenticationRequest.class);
+		given(mockAuthenticationRequest.getAuthenticationRequestUri()).willReturn(IDP_SSO_URL);
+		this.request.getSession();
+		this.authenticationRequestRepository.saveAuthenticationRequest(mockAuthenticationRequest, this.request,
+				this.response);
+		AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestRepository
+				.loadAuthenticationRequest(this.request);
+		assertThat(authenticationRequest.getAuthenticationRequestUri()).isEqualTo(IDP_SSO_URL);
+	}
+
+	@Test
+	public void saveAuthenticationRequestWhenSessionDontExistsThenCreateAndSave() {
+		AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mock(AbstractSaml2AuthenticationRequest.class);
+		this.authenticationRequestRepository.saveAuthenticationRequest(mockAuthenticationRequest, this.request,
+				this.response);
+		AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestRepository
+				.loadAuthenticationRequest(this.request);
+		assertThat(authenticationRequest).isNotNull();
+	}
+
+	@Test
+	public void saveAuthenticationRequestWhenSessionExistsThenSave() {
+		AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mock(AbstractSaml2AuthenticationRequest.class);
+		this.request.getSession();
+		this.authenticationRequestRepository.saveAuthenticationRequest(mockAuthenticationRequest, this.request,
+				this.response);
+		AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestRepository
+				.loadAuthenticationRequest(this.request);
+		assertThat(authenticationRequest).isNotNull();
+	}
+
+	@Test
+	public void saveAuthenticationRequestWhenNullAuthenticationRequestThenDontSave() {
+		this.request.getSession();
+		this.authenticationRequestRepository.saveAuthenticationRequest(null, this.request, this.response);
+		AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestRepository
+				.loadAuthenticationRequest(this.request);
+		assertThat(authenticationRequest).isNull();
+	}
+
+	@Test
+	public void removeAuthenticationRequestWhenInvalidSessionThenReturnNull() {
+		AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestRepository
+				.removeAuthenticationRequest(this.request, this.response);
+		assertThat(authenticationRequest).isNull();
+	}
+
+	@Test
+	public void removeAuthenticationRequestWhenAttributeInSessionThenRemoveAuthenticationRequest() {
+		AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mock(AbstractSaml2AuthenticationRequest.class);
+		given(mockAuthenticationRequest.getAuthenticationRequestUri()).willReturn(IDP_SSO_URL);
+		this.request.getSession();
+		this.authenticationRequestRepository.saveAuthenticationRequest(mockAuthenticationRequest, this.request,
+				this.response);
+		AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestRepository
+				.removeAuthenticationRequest(this.request, this.response);
+		AbstractSaml2AuthenticationRequest authenticationRequestAfterRemove = this.authenticationRequestRepository
+				.loadAuthenticationRequest(this.request);
+		assertThat(authenticationRequest.getAuthenticationRequestUri()).isEqualTo(IDP_SSO_URL);
+		assertThat(authenticationRequestAfterRemove).isNull();
+	}
+
+	@Test
+	public void removeAuthenticationRequestWhenValidSessionNoAttributeThenReturnsNull() {
+		MockHttpSession session = mock(MockHttpSession.class);
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		request.setSession(session);
+		AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestRepository
+				.removeAuthenticationRequest(request, this.response);
+		verify(session).getAttribute(anyString());
+		assertThat(authenticationRequest).isNull();
+	}
+
+}

+ 50 - 1
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilterTests.java

@@ -1,5 +1,5 @@
 /*
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * you may not use this file except in compliance with the License.
@@ -24,12 +24,20 @@ import org.junit.jupiter.api.Test;
 
 
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.mock.web.MockHttpServletResponse;
+import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
+import org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationTokens;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
+import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository;
+import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter;
+import org.springframework.security.web.authentication.AuthenticationConverter;
 
 
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
+import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoInteractions;
 
 
 public class Saml2WebSsoAuthenticationFilterTests {
 public class Saml2WebSsoAuthenticationFilterTests {
 
 
@@ -84,4 +92,45 @@ public class Saml2WebSsoAuthenticationFilterTests {
 				.withMessage("No relying party registration found");
 				.withMessage("No relying party registration found");
 	}
 	}
 
 
+	@Test
+	public void attemptAuthenticationWhenSavedAuthnRequestThenRemovesAuthnRequest() {
+		Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository = mock(
+				Saml2AuthenticationRequestRepository.class);
+		AuthenticationConverter authenticationConverter = mock(AuthenticationConverter.class);
+		given(authenticationConverter.convert(this.request)).willReturn(TestSaml2AuthenticationTokens.token());
+		this.filter = new Saml2WebSsoAuthenticationFilter(authenticationConverter, "/some/other/path/{registrationId}");
+		this.filter.setAuthenticationManager((authentication) -> null);
+		this.request.setPathInfo("/some/other/path/idp-registration-id");
+		this.filter.setAuthenticationRequestRepository(authenticationRequestRepository);
+		this.filter.attemptAuthentication(this.request, this.response);
+		verify(authenticationRequestRepository).removeAuthenticationRequest(this.request, this.response);
+	}
+
+	@Test
+	public void setAuthenticationRequestRepositoryWhenNullThenThrowsIllegalArgument() {
+		assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthenticationRequestRepository(null))
+				.withMessage("authenticationRequestRepository cannot be null");
+	}
+
+	@Test
+	public void setAuthenticationRequestRepositoryWhenExpectedAuthenticationConverterTypeThenSetLoaderIntoConverter() {
+		Saml2AuthenticationTokenConverter authenticationConverterMock = mock(Saml2AuthenticationTokenConverter.class);
+		Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository = mock(
+				Saml2AuthenticationRequestRepository.class);
+		this.filter = new Saml2WebSsoAuthenticationFilter(authenticationConverterMock,
+				"/some/other/path/{registrationId}");
+		this.filter.setAuthenticationRequestRepository(authenticationRequestRepository);
+		verify(authenticationConverterMock).setAuthenticationRequestRepository(authenticationRequestRepository);
+	}
+
+	@Test
+	public void setAuthenticationRequestRepositoryWhenNotExpectedAuthenticationConverterTypeThenDontSet() {
+		AuthenticationConverter authenticationConverter = mock(AuthenticationConverter.class);
+		Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository = mock(
+				Saml2AuthenticationRequestRepository.class);
+		this.filter = new Saml2WebSsoAuthenticationFilter(authenticationConverter, "/some/other/path/{registrationId}");
+		this.filter.setAuthenticationRequestRepository(authenticationRequestRepository);
+		verifyNoInteractions(authenticationConverter);
+	}
+
 }
 }

+ 41 - 1
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java

@@ -1,5 +1,5 @@
 /*
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * you may not use this file except in compliance with the License.
@@ -28,6 +28,7 @@ import org.springframework.mock.web.MockFilterChain;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.security.saml2.credentials.TestSaml2X509Credentials;
 import org.springframework.security.saml2.credentials.TestSaml2X509Credentials;
+import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory;
 import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest;
 import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest;
@@ -36,6 +37,7 @@ import org.springframework.security.saml2.provider.service.authentication.TestSa
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
 import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
 import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
+import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository;
 import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
 import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
 import org.springframework.web.util.HtmlUtils;
 import org.springframework.web.util.HtmlUtils;
 import org.springframework.web.util.UriUtils;
 import org.springframework.web.util.UriUtils;
@@ -43,6 +45,7 @@ import org.springframework.web.util.UriUtils;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verify;
@@ -60,6 +63,9 @@ public class Saml2WebSsoAuthenticationRequestFilterTests {
 
 
 	private Saml2AuthenticationRequestContextResolver resolver = mock(Saml2AuthenticationRequestContextResolver.class);
 	private Saml2AuthenticationRequestContextResolver resolver = mock(Saml2AuthenticationRequestContextResolver.class);
 
 
+	private Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository = mock(
+			Saml2AuthenticationRequestRepository.class);
+
 	private MockHttpServletRequest request;
 	private MockHttpServletRequest request;
 
 
 	private MockHttpServletResponse response;
 	private MockHttpServletResponse response;
@@ -79,6 +85,7 @@ public class Saml2WebSsoAuthenticationRequestFilterTests {
 				.providerDetails((c) -> c.entityId("idp-entity-id")).providerDetails((c) -> c.webSsoUrl(IDP_SSO_URL))
 				.providerDetails((c) -> c.entityId("idp-entity-id")).providerDetails((c) -> c.webSsoUrl(IDP_SSO_URL))
 				.assertionConsumerServiceUrlTemplate("template")
 				.assertionConsumerServiceUrlTemplate("template")
 				.credentials((c) -> c.add(TestSaml2X509Credentials.assertingPartyPrivateCredential()));
 				.credentials((c) -> c.add(TestSaml2X509Credentials.assertingPartyPrivateCredential()));
+		this.filter.setAuthenticationRequestRepository(this.authenticationRequestRepository);
 	}
 	}
 
 
 	@Test
 	@Test
@@ -216,4 +223,37 @@ public class Saml2WebSsoAuthenticationRequestFilterTests {
 		assertThat(this.response.getStatus()).isEqualTo(401);
 		assertThat(this.response.getStatus()).isEqualTo(401);
 	}
 	}
 
 
+	@Test
+	public void setAuthenticationRequestRepositoryWhenNullThenException() {
+		Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter(this.resolver,
+				this.factory);
+		assertThatIllegalArgumentException().isThrownBy(() -> filter.setAuthenticationRequestRepository(null));
+	}
+
+	@Test
+	public void doFilterWhenRedirectThenSaveRedirectRequest() throws ServletException, IOException {
+		Saml2AuthenticationRequestContext context = authenticationRequestContext().build();
+		Saml2RedirectAuthenticationRequest request = redirectAuthenticationRequest(context).build();
+		given(this.resolver.resolve(any())).willReturn(context);
+		given(this.factory.createRedirectAuthenticationRequest(any())).willReturn(request);
+		this.filter.doFilterInternal(this.request, this.response, this.filterChain);
+		verify(this.authenticationRequestRepository).saveAuthenticationRequest(
+				any(Saml2RedirectAuthenticationRequest.class), eq(this.request), eq(this.response));
+	}
+
+	@Test
+	public void doFilterWhenPostThenSaveRedirectRequest() throws ServletException, IOException {
+		RelyingPartyRegistration registration = this.rpBuilder
+				.assertingPartyDetails((asserting) -> asserting.singleSignOnServiceBinding(Saml2MessageBinding.POST))
+				.build();
+		Saml2AuthenticationRequestContext context = authenticationRequestContext()
+				.relyingPartyRegistration(registration).build();
+		Saml2PostAuthenticationRequest request = postAuthenticationRequest(context).build();
+		given(this.resolver.resolve(any())).willReturn(context);
+		given(this.factory.createPostAuthenticationRequest(any())).willReturn(request);
+		this.filter.doFilterInternal(this.request, this.response, this.filterChain);
+		verify(this.authenticationRequestRepository).saveAuthenticationRequest(
+				any(Saml2PostAuthenticationRequest.class), eq(this.request), eq(this.response));
+	}
+
 }
 }

+ 24 - 0
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java

@@ -31,10 +31,12 @@ import org.springframework.core.io.ClassPathResource;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.security.saml2.core.Saml2ErrorCodes;
 import org.springframework.security.saml2.core.Saml2ErrorCodes;
 import org.springframework.security.saml2.core.Saml2Utils;
 import org.springframework.security.saml2.core.Saml2Utils;
+import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
 import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
+import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository;
 import org.springframework.util.StreamUtils;
 import org.springframework.util.StreamUtils;
 import org.springframework.web.util.UriUtils;
 import org.springframework.web.util.UriUtils;
 
 
@@ -43,6 +45,7 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.BDDMockito.given;
+import static org.mockito.Mockito.mock;
 
 
 @ExtendWith(MockitoExtension.class)
 @ExtendWith(MockitoExtension.class)
 public class Saml2AuthenticationTokenConverterTests {
 public class Saml2AuthenticationTokenConverterTests {
@@ -155,6 +158,27 @@ public class Saml2AuthenticationTokenConverterTests {
 		validateSsoCircleXml(token.getSaml2Response());
 		validateSsoCircleXml(token.getSaml2Response());
 	}
 	}
 
 
+	@Test
+	public void convertWhenSavedAuthenticationRequestThenToken() {
+		Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository = mock(
+				Saml2AuthenticationRequestRepository.class);
+		AbstractSaml2AuthenticationRequest authenticationRequest = mock(AbstractSaml2AuthenticationRequest.class);
+		Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(
+				this.relyingPartyRegistrationResolver);
+		converter.setAuthenticationRequestRepository(authenticationRequestRepository);
+		given(this.relyingPartyRegistrationResolver.convert(any(HttpServletRequest.class)))
+				.willReturn(this.relyingPartyRegistration);
+		given(authenticationRequestRepository.loadAuthenticationRequest(any(HttpServletRequest.class)))
+				.willReturn(authenticationRequest);
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		request.setParameter("SAMLResponse", Saml2Utils.samlEncode("response".getBytes(StandardCharsets.UTF_8)));
+		Saml2AuthenticationToken token = converter.convert(request);
+		assertThat(token.getSaml2Response()).isEqualTo("response");
+		assertThat(token.getRelyingPartyRegistration().getRegistrationId())
+				.isEqualTo(this.relyingPartyRegistration.getRegistrationId());
+		assertThat(token.getAuthenticationRequest()).isEqualTo(authenticationRequest);
+	}
+
 	private void validateSsoCircleXml(String xml) {
 	private void validateSsoCircleXml(String xml) {
 		assertThat(xml).contains("InResponseTo=\"ARQ9a73ead-7dcf-45a8-89eb-26f3c9900c36\"")
 		assertThat(xml).contains("InResponseTo=\"ARQ9a73ead-7dcf-45a8-89eb-26f3c9900c36\"")
 				.contains(" ID=\"s246d157446618e90e43fb79bdd4d9e9e19cf2c7c4\"")
 				.contains(" ID=\"s246d157446618e90e43fb79bdd4d9e9e19cf2c7c4\"")