瀏覽代碼

Polish DefaultSaml2AuthenticationRequestContextResolver

Issue gh-8360
Issue gh-8887
Josh Cummings 5 年之前
父節點
當前提交
a10c2c6cf8

+ 17 - 3
config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -35,6 +35,9 @@ import org.springframework.security.saml2.provider.service.registration.RelyingP
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
 import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
 import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationRequestFilter;
+import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
+import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver;
+import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
 import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint;
 import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
@@ -317,15 +320,16 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>> extend
 
 	private final class AuthenticationRequestEndpointConfig {
 		private String filterProcessingUrl = "/saml2/authenticate/{registrationId}";
+
 		private AuthenticationRequestEndpointConfig() {
 		}
 
 		private Filter build(B http) {
 			Saml2AuthenticationRequestFactory authenticationRequestResolver = getResolver(http);
+			Saml2AuthenticationRequestContextResolver contextResolver = getContextResolver(http);
 
 			return postProcess(new Saml2WebSsoAuthenticationRequestFilter(
-							Saml2LoginConfigurer.this.relyingPartyRegistrationRepository,
-							authenticationRequestResolver));
+					contextResolver, authenticationRequestResolver));
 		}
 
 		private Saml2AuthenticationRequestFactory getResolver(B http) {
@@ -335,6 +339,16 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>> extend
 			}
 			return resolver;
 		}
+
+		private Saml2AuthenticationRequestContextResolver getContextResolver(B http) {
+			Saml2AuthenticationRequestContextResolver resolver = getBeanOrNull(http, Saml2AuthenticationRequestContextResolver.class);
+			if (resolver == null) {
+				return new DefaultSaml2AuthenticationRequestContextResolver(
+						new DefaultRelyingPartyRegistrationResolver(
+								Saml2LoginConfigurer.this.relyingPartyRegistrationRepository));
+			}
+			return resolver;
+		}
 	}
 
 }

+ 4 - 16
config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java

@@ -65,10 +65,8 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2A
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
-import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
 import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
-import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationRequestFilter;
 import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
 import org.springframework.security.web.FilterChainProxy;
 import org.springframework.security.web.context.HttpRequestResponseHolder;
@@ -87,6 +85,7 @@ import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
+import static org.springframework.security.config.Customizer.withDefaults;
 import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts.authenticationRequestContext;
 import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration;
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
@@ -161,11 +160,11 @@ public class Saml2LoginConfigurerTests {
 		Saml2AuthenticationRequestContext context = authenticationRequestContext().build();
 		Saml2AuthenticationRequestContextResolver resolver =
 				CustomAuthenticationRequestContextResolver.resolver;
-		when(resolver.resolve(any(HttpServletRequest.class), any(RelyingPartyRegistration.class)))
+		when(resolver.resolve(any(HttpServletRequest.class)))
 				.thenReturn(context);
 		this.mvc.perform(get("/saml2/authenticate/registration-id"))
 				.andExpect(status().isFound());
-		verify(resolver).resolve(any(HttpServletRequest.class), any(RelyingPartyRegistration.class));
+		verify(resolver).resolve(any(HttpServletRequest.class));
 	}
 
 	@Test
@@ -276,22 +275,11 @@ public class Saml2LoginConfigurerTests {
 
 		@Override
 		protected void configure(HttpSecurity http) throws Exception {
-			ObjectPostProcessor<Saml2WebSsoAuthenticationRequestFilter> processor
-					= new ObjectPostProcessor<Saml2WebSsoAuthenticationRequestFilter>() {
-				@Override
-				public <O extends Saml2WebSsoAuthenticationRequestFilter> O postProcess(O filter) {
-					filter.setAuthenticationRequestContextResolver(resolver);
-					return filter;
-				}
-			};
-
 			http
 				.authorizeRequests(authz -> authz
 						.anyRequest().authenticated()
 				)
-				.saml2Login(saml2 -> saml2
-						.addObjectPostProcessor(processor)
-				);
+				.saml2Login(withDefaults());
 		}
 
 		@Bean

+ 13 - 23
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java

@@ -30,6 +30,7 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2R
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
 import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
+import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
 import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver;
 import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
@@ -69,9 +70,8 @@ import static java.nio.charset.StandardCharsets.ISO_8859_1;
  */
 public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter {
 
-	private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;
+	private final Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver;
 	private Saml2AuthenticationRequestFactory authenticationRequestFactory;
-	private Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver = new DefaultSaml2AuthenticationRequestContextResolver();
 
 	private RequestMatcher redirectMatcher = new AntPathRequestMatcher("/saml2/authenticate/{registrationId}");
 
@@ -83,21 +83,24 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
 	 */
 	@Deprecated
 	public Saml2WebSsoAuthenticationRequestFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) {
-		this(relyingPartyRegistrationRepository,
+		this(new DefaultSaml2AuthenticationRequestContextResolver(
+				new DefaultRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository)),
 				new org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationRequestFactory());
 	}
 
 	/**
 	 * Construct a {@link Saml2WebSsoAuthenticationRequestFilter} with the provided parameters
 	 *
-	 * @param relyingPartyRegistrationRepository a repository for relying party configurations
+	 * @param authenticationRequestContextResolver a strategy for formulating a {@link Saml2AuthenticationRequestContext}
 	 * @since 5.4
 	 */
-	public Saml2WebSsoAuthenticationRequestFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository,
+	public Saml2WebSsoAuthenticationRequestFilter(
+			Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver,
 			Saml2AuthenticationRequestFactory authenticationRequestFactory) {
-		Assert.notNull(relyingPartyRegistrationRepository, "relyingPartyRegistrationRepository cannot be null");
+
+		Assert.notNull(authenticationRequestContextResolver, "authenticationRequestContextResolver cannot be null");
 		Assert.notNull(authenticationRequestFactory, "authenticationRequestFactory cannot be null");
-		this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository;
+		this.authenticationRequestContextResolver = authenticationRequestContextResolver;
 		this.authenticationRequestFactory = authenticationRequestFactory;
 	}
 
@@ -123,17 +126,6 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
 		this.redirectMatcher = redirectMatcher;
 	}
 
-	/**
-	 * Use the given {@link Saml2AuthenticationRequestContextResolver} that creates a {@link Saml2AuthenticationRequestContext}
-	 *
-	 * @param authenticationRequestContextResolver the {@link Saml2AuthenticationRequestContextResolver} to use
-	 * @since 5.4
-	 */
-	public void setAuthenticationRequestContextResolver(Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver) {
-		Assert.notNull(authenticationRequestContextResolver, "authenticationRequestContextResolver cannot be null");
-		this.authenticationRequestContextResolver = authenticationRequestContextResolver;
-	}
-
 	/**
 	 * {@inheritDoc}
 	 */
@@ -147,14 +139,12 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
 			return;
 		}
 
-		String registrationId = matcher.getVariables().get("registrationId");
-		RelyingPartyRegistration relyingParty =
-				this.relyingPartyRegistrationRepository.findByRegistrationId(registrationId);
-		if (relyingParty == null) {
+		Saml2AuthenticationRequestContext context = this.authenticationRequestContextResolver.resolve(request);
+		if (context == null) {
 			response.sendError(HttpServletResponse.SC_UNAUTHORIZED);
 			return;
 		}
-		Saml2AuthenticationRequestContext context = authenticationRequestContextResolver.resolve(request, relyingParty);
+		RelyingPartyRegistration relyingParty = context.getRelyingPartyRegistration();
 		if (relyingParty.getAssertingPartyDetails().getSingleSignOnServiceBinding() == Saml2MessageBinding.REDIRECT) {
 			sendRedirect(response, context);
 		} else {

+ 15 - 63
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolver.java

@@ -16,45 +16,45 @@
 
 package org.springframework.security.saml2.provider.service.web;
 
-import java.util.HashMap;
-import java.util.Map;
-import java.util.function.Function;
 import javax.servlet.http.HttpServletRequest;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 
+import org.springframework.core.convert.converter.Converter;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 import org.springframework.util.Assert;
-import org.springframework.util.StringUtils;
-import org.springframework.web.util.UriComponents;
-import org.springframework.web.util.UriComponentsBuilder;
-
-import static org.springframework.security.web.util.UrlUtils.buildFullRequestUrl;
-import static org.springframework.web.util.UriComponentsBuilder.fromHttpUrl;
 
 /**
  * The default implementation for {@link Saml2AuthenticationRequestContextResolver}
  * which uses the current request and given relying party to formulate a {@link Saml2AuthenticationRequestContext}
  *
  * @author Shazin Sadakath
+ * @author Josh Cummings
  * @since 5.4
  */
 public final class DefaultSaml2AuthenticationRequestContextResolver implements Saml2AuthenticationRequestContextResolver {
 
 	private final Log logger = LogFactory.getLog(getClass());
 
-	private static final char PATH_DELIMITER = '/';
+	private final Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver;
+
+	public DefaultSaml2AuthenticationRequestContextResolver
+			(Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver) {
+		this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver;
+	}
 
 	/**
 	 * {@inheritDoc}
 	 */
 	@Override
-	public Saml2AuthenticationRequestContext resolve(HttpServletRequest request,
-			RelyingPartyRegistration relyingParty) {
+	public Saml2AuthenticationRequestContext resolve(HttpServletRequest request) {
 		Assert.notNull(request, "request cannot be null");
-		Assert.notNull(relyingParty, "relyingParty cannot be null");
+		RelyingPartyRegistration relyingParty = this.relyingPartyRegistrationResolver.convert(request);
+		if (relyingParty == null) {
+			return null;
+		}
 		if (this.logger.isDebugEnabled()) {
 			this.logger.debug("Creating SAML 2.0 Authentication Request for Asserting Party [" +
 					relyingParty.getRegistrationId() + "]");
@@ -65,59 +65,11 @@ public final class DefaultSaml2AuthenticationRequestContextResolver implements S
 	private Saml2AuthenticationRequestContext createRedirectAuthenticationRequestContext(
 			HttpServletRequest request, RelyingPartyRegistration relyingParty) {
 
-		String applicationUri = getApplicationUri(request);
-		Function<String, String> resolver = templateResolver(applicationUri, relyingParty);
-		String localSpEntityId = resolver.apply(relyingParty.getEntityId());
-		String assertionConsumerServiceUrl = resolver.apply(relyingParty.getAssertionConsumerServiceLocation());
 		return Saml2AuthenticationRequestContext.builder()
-				.issuer(localSpEntityId)
+				.issuer(relyingParty.getEntityId())
 				.relyingPartyRegistration(relyingParty)
-				.assertionConsumerServiceUrl(assertionConsumerServiceUrl)
+				.assertionConsumerServiceUrl(relyingParty.getAssertionConsumerServiceLocation())
 				.relayState(request.getParameter("RelayState"))
 				.build();
 	}
-
-	private Function<String, String> templateResolver(String applicationUri, RelyingPartyRegistration relyingParty) {
-		return template -> resolveUrlTemplate(template, applicationUri, relyingParty);
-	}
-
-	private static String resolveUrlTemplate(String template, String baseUrl, RelyingPartyRegistration relyingParty) {
-		String entityId = relyingParty.getAssertingPartyDetails().getEntityId();
-		String registrationId = relyingParty.getRegistrationId();
-		Map<String, String> uriVariables = new HashMap<>();
-		UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(baseUrl)
-				.replaceQuery(null)
-				.fragment(null)
-				.build();
-		String scheme = uriComponents.getScheme();
-		uriVariables.put("baseScheme", scheme == null ? "" : scheme);
-		String host = uriComponents.getHost();
-		uriVariables.put("baseHost", host == null ? "" : host);
-		// following logic is based on HierarchicalUriComponents#toUriString()
-		int port = uriComponents.getPort();
-		uriVariables.put("basePort", port == -1 ? "" : ":" + port);
-		String path = uriComponents.getPath();
-		if (StringUtils.hasLength(path)) {
-			if (path.charAt(0) != PATH_DELIMITER) {
-				path = PATH_DELIMITER + path;
-			}
-		}
-		uriVariables.put("basePath", path == null ? "" : path);
-		uriVariables.put("baseUrl", uriComponents.toUriString());
-		uriVariables.put("entityId", StringUtils.hasText(entityId) ? entityId : "");
-		uriVariables.put("registrationId", StringUtils.hasText(registrationId) ? registrationId : "");
-
-		return UriComponentsBuilder.fromUriString(template)
-				.buildAndExpand(uriVariables)
-				.toUriString();
-	}
-
-	private static String getApplicationUri(HttpServletRequest request) {
-		UriComponents uriComponents = fromHttpUrl(buildFullRequestUrl(request))
-				.replacePath(request.getContextPath())
-				.replaceQuery(null)
-				.fragment(null)
-				.build();
-		return uriComponents.toUriString();
-	}
 }

+ 5 - 7
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationRequestContextResolver.java

@@ -16,16 +16,16 @@
 
 package org.springframework.security.saml2.provider.service.web;
 
-import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
-import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
-
 import javax.servlet.http.HttpServletRequest;
 
+import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
+
 /**
  * This {@code Saml2AuthenticationRequestContextResolver} formulates a
  * <a href="https://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf">SAML 2.0 AuthnRequest</a> (line 1968)
  *
  * @author Shazin Sadakath
+ * @author Josh Cummings
  * @since 5.4
  */
 public interface Saml2AuthenticationRequestContextResolver {
@@ -35,9 +35,7 @@ public interface Saml2AuthenticationRequestContextResolver {
 	 *
 	 *
 	 * @param request the current request
-	 * @param relyingParty the relying party responsible for saml2 sso authentication
-	 * @return the created {@link Saml2AuthenticationRequestContext} for request/relying party combination
+	 * @return the created {@link Saml2AuthenticationRequestContext} for the request
 	 */
-	Saml2AuthenticationRequestContext resolve(HttpServletRequest request,
-			RelyingPartyRegistration relyingParty);
+	Saml2AuthenticationRequestContext resolve(HttpServletRequest request);
 }

+ 8 - 2
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java

@@ -30,6 +30,7 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2A
 import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
+import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
 import org.springframework.web.util.HtmlUtils;
 import org.springframework.web.util.UriUtils;
 
@@ -41,6 +42,7 @@ import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoInteractions;
 import static org.mockito.Mockito.when;
 import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyPrivateCredential;
+import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts.authenticationRequestContext;
 import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST;
 
 public class Saml2WebSsoAuthenticationRequestFilterTests {
@@ -49,6 +51,8 @@ public class Saml2WebSsoAuthenticationRequestFilterTests {
 	private Saml2WebSsoAuthenticationRequestFilter filter;
 	private RelyingPartyRegistrationRepository repository = mock(RelyingPartyRegistrationRepository.class);
 	private Saml2AuthenticationRequestFactory factory = mock(Saml2AuthenticationRequestFactory.class);
+	private Saml2AuthenticationRequestContextResolver resolver =
+			mock(Saml2AuthenticationRequestContextResolver.class);
 	private MockHttpServletRequest request;
 	private MockHttpServletResponse response;
 	private MockFilterChain filterChain;
@@ -188,12 +192,14 @@ public class Saml2WebSsoAuthenticationRequestFilterTests {
 		when(authenticationRequest.getAuthenticationRequestUri()).thenReturn("uri");
 		when(authenticationRequest.getRelayState()).thenReturn("relay");
 		when(authenticationRequest.getSamlRequest()).thenReturn("saml");
-		when(this.repository.findByRegistrationId("registration-id")).thenReturn(relyingParty);
+		when(this.resolver.resolve(this.request)).thenReturn(authenticationRequestContext()
+				.relyingPartyRegistration(relyingParty)
+				.build());
 		when(this.factory.createPostAuthenticationRequest(any()))
 				.thenReturn(authenticationRequest);
 
 		Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter
-				(this.repository, this.factory);
+				(this.resolver, this.factory);
 		filter.doFilterInternal(this.request, this.response, this.filterChain);
 		assertThat(this.response.getContentAsString())
 				.contains("<form action=\"uri\" method=\"post\">")

+ 13 - 20
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolverTests.java

@@ -44,11 +44,13 @@ public class DefaultSaml2AuthenticationRequestContextResolverTests {
 	private MockHttpServletRequest request;
 	private RelyingPartyRegistration.Builder relyingPartyBuilder;
 	private Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver
-			= new DefaultSaml2AuthenticationRequestContextResolver();
+			= new DefaultSaml2AuthenticationRequestContextResolver(
+					new DefaultRelyingPartyRegistrationResolver(id -> relyingPartyBuilder.build()));
 
 	@Before
 	public void setup() {
 		this.request = new MockHttpServletRequest();
+		this.request.setPathInfo("/saml2/authenticate/registration-id");
 		this.relyingPartyBuilder = RelyingPartyRegistration
 				.withRegistrationId(REGISTRATION_ID)
 				.localEntityIdTemplate(RELYING_PARTY_ENTITY_ID)
@@ -61,52 +63,43 @@ public class DefaultSaml2AuthenticationRequestContextResolverTests {
 	@Test
 	public void resolveWhenRequestAndRelyingPartyNotNullThenCreateSaml2AuthenticationRequestContext() {
 		this.request.addParameter("RelayState", "relay-state");
-		RelyingPartyRegistration relyingParty = this.relyingPartyBuilder.build();
 		Saml2AuthenticationRequestContext context =
-				this.authenticationRequestContextResolver.resolve(this.request, relyingParty);
+				this.authenticationRequestContextResolver.resolve(this.request);
 
 		assertThat(context).isNotNull();
 		assertThat(context.getAssertionConsumerServiceUrl()).isEqualTo(RELYING_PARTY_SSO_URL);
 		assertThat(context.getRelayState()).isEqualTo("relay-state");
 		assertThat(context.getDestination()).isEqualTo(ASSERTING_PARTY_SSO_URL);
 		assertThat(context.getIssuer()).isEqualTo(RELYING_PARTY_ENTITY_ID);
-		assertThat(context.getRelyingPartyRegistration()).isSameAs(relyingParty);
+		assertThat(context.getRelyingPartyRegistration().getRegistrationId())
+				.isSameAs(this.relyingPartyBuilder.build().getRegistrationId());
 	}
 
 	@Test
 	public void resolveWhenAssertionConsumerServiceUrlTemplateContainsRegistrationIdThenResolves() {
-		RelyingPartyRegistration relyingParty = this.relyingPartyBuilder
-				.assertionConsumerServiceUrlTemplate("/saml2/authenticate/{registrationId}")
-				.build();
+		this.relyingPartyBuilder
+				.assertionConsumerServiceLocation("/saml2/authenticate/{registrationId}");
 		Saml2AuthenticationRequestContext context =
-				this.authenticationRequestContextResolver.resolve(this.request, relyingParty);
+				this.authenticationRequestContextResolver.resolve(this.request);
 
 		assertThat(context.getAssertionConsumerServiceUrl()).isEqualTo("/saml2/authenticate/registration-id");
 	}
 
 	@Test
 	public void resolveWhenAssertionConsumerServiceUrlTemplateContainsBaseUrlThenResolves() {
-		RelyingPartyRegistration relyingParty = this.relyingPartyBuilder
-				.assertionConsumerServiceUrlTemplate("{baseUrl}/saml2/authenticate/{registrationId}")
-				.build();
+		this.relyingPartyBuilder
+				.assertionConsumerServiceLocation("{baseUrl}/saml2/authenticate/{registrationId}");
 		Saml2AuthenticationRequestContext context =
-				this.authenticationRequestContextResolver.resolve(this.request, relyingParty);
+				this.authenticationRequestContextResolver.resolve(this.request);
 
 		assertThat(context.getAssertionConsumerServiceUrl())
 				.isEqualTo("http://localhost/saml2/authenticate/registration-id");
 	}
 
-	@Test
-	public void resolveWhenRequestNullThenException() {
-		assertThatCode(() ->
-				this.authenticationRequestContextResolver.resolve(this.request, null))
-						.isInstanceOf(IllegalArgumentException.class);
-	}
-
 	@Test
 	public void resolveWhenRelyingPartyNullThenException() {
 		assertThatCode(() ->
-				this.authenticationRequestContextResolver.resolve(null, this.relyingPartyBuilder.build()))
+				this.authenticationRequestContextResolver.resolve(null))
 				.isInstanceOf(IllegalArgumentException.class);
 	}
 }