Bläddra i källkod

Polish DefaultSaml2AuthenticationRequestContextResolver

- Added more tests
- Standardized terminology

Issue gh-8360
Josh Cummings 5 år sedan
förälder
incheckning
ab772893c7

+ 6 - 9
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolver.java

@@ -16,8 +16,14 @@
 
 package org.springframework.security.saml2.provider.service.web;
 
+import java.util.HashMap;
+import java.util.Map;
+import java.util.function.Function;
+import javax.servlet.http.HttpServletRequest;
+
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
+
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 import org.springframework.util.Assert;
@@ -25,11 +31,6 @@ import org.springframework.util.StringUtils;
 import org.springframework.web.util.UriComponents;
 import org.springframework.web.util.UriComponentsBuilder;
 
-import javax.servlet.http.HttpServletRequest;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.function.Function;
-
 import static org.springframework.security.web.util.UrlUtils.buildFullRequestUrl;
 import static org.springframework.web.util.UriComponentsBuilder.fromHttpUrl;
 
@@ -81,10 +82,6 @@ public final class DefaultSaml2AuthenticationRequestContextResolver implements S
 	}
 
 	private static String resolveUrlTemplate(String template, String baseUrl, RelyingPartyRegistration relyingParty) {
-		if (!StringUtils.hasText(template)) {
-			return baseUrl;
-		}
-
 		String entityId = relyingParty.getProviderDetails().getEntityId();
 		String registrationId = relyingParty.getRegistrationId();
 		Map<String, String> uriVariables = new HashMap<>();

+ 67 - 22
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolverTests.java

@@ -23,44 +23,89 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2A
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 
 import static org.springframework.security.saml2.provider.service.servlet.filter.TestSaml2SigningCredentials.signingCredential;
-import static org.assertj.core.api.Assertions.*;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatCode;
 
+/**
+ * Tests for {@link DefaultSaml2AuthenticationRequestContextResolver}
+ *
+ * @author Shazin Sadakath
+ * @author Josh Cummings
+ */
 public class DefaultSaml2AuthenticationRequestContextResolverTests {
 
-	private static final String IDP_SSO_URL = "https://sso-url.example.com/IDP/SSO";
-	private static final String TEMPLATE = "template";
+	private static final String ASSERTING_PARTY_SSO_URL = "https://idp.example.com/sso";
+	private static final String RELYING_PARTY_SSO_URL = "https://sp.example.com/sso";
+	private static final String ASSERTING_PARTY_ENTITY_ID = "asserting-party-entity-id";
+	private static final String RELYING_PARTY_ENTITY_ID = "relying-party-entity-id";
 	private static final String REGISTRATION_ID = "registration-id";
-	private static final String IDP_ENTITY_ID = "idp-entity-id";
 
 	private MockHttpServletRequest request;
-	private RelyingPartyRegistration.Builder rpBuilder;
-	private Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver = new DefaultSaml2AuthenticationRequestContextResolver();
+	private RelyingPartyRegistration.Builder relyingPartyBuilder;
+	private Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver
+			= new DefaultSaml2AuthenticationRequestContextResolver();
 
 	@Before
 	public void setup() {
-		request = new MockHttpServletRequest();
-		rpBuilder = RelyingPartyRegistration
+		this.request = new MockHttpServletRequest();
+		this.relyingPartyBuilder = RelyingPartyRegistration
 				.withRegistrationId(REGISTRATION_ID)
-				.providerDetails(c -> c.entityId(IDP_ENTITY_ID))
-				.providerDetails(c -> c.webSsoUrl(IDP_SSO_URL))
-				.assertionConsumerServiceUrlTemplate(TEMPLATE)
+				.localEntityIdTemplate(RELYING_PARTY_ENTITY_ID)
+				.providerDetails(c -> c.entityId(ASSERTING_PARTY_ENTITY_ID))
+				.providerDetails(c -> c.webSsoUrl(ASSERTING_PARTY_SSO_URL))
+				.assertionConsumerServiceUrlTemplate(RELYING_PARTY_SSO_URL)
 				.credentials(c -> c.add(signingCredential()));
 	}
 
 	@Test
-	public void resoleWhenRequestAndRelyingPartyNotNullThenCreateSaml2AuthenticationRequestContext() {
-		Saml2AuthenticationRequestContext authenticationRequestContext = authenticationRequestContextResolver.resolve(request, rpBuilder.build());
+	public void resolveWhenRequestAndRelyingPartyNotNullThenCreateSaml2AuthenticationRequestContext() {
+		this.request.addParameter("RelayState", "relay-state");
+		RelyingPartyRegistration relyingParty = this.relyingPartyBuilder.build();
+		Saml2AuthenticationRequestContext context =
+				this.authenticationRequestContextResolver.resolve(this.request, relyingParty);
+
+		assertThat(context).isNotNull();
+		assertThat(context.getAssertionConsumerServiceUrl()).isEqualTo(RELYING_PARTY_SSO_URL);
+		assertThat(context.getRelayState()).isEqualTo("relay-state");
+		assertThat(context.getDestination()).isEqualTo(ASSERTING_PARTY_SSO_URL);
+		assertThat(context.getIssuer()).isEqualTo(RELYING_PARTY_ENTITY_ID);
+		assertThat(context.getRelyingPartyRegistration()).isSameAs(relyingParty);
+	}
+
+	@Test
+	public void resolveWhenAssertionConsumerServiceUrlTemplateContainsRegistrationIdThenResolves() {
+		RelyingPartyRegistration relyingParty = this.relyingPartyBuilder
+				.assertionConsumerServiceUrlTemplate("/saml2/authenticate/{registrationId}")
+				.build();
+		Saml2AuthenticationRequestContext context =
+				this.authenticationRequestContextResolver.resolve(this.request, relyingParty);
+
+		assertThat(context.getAssertionConsumerServiceUrl()).isEqualTo("/saml2/authenticate/registration-id");
+	}
+
+	@Test
+	public void resolveWhenAssertionConsumerServiceUrlTemplateContainsBaseUrlThenResolves() {
+		RelyingPartyRegistration relyingParty = this.relyingPartyBuilder
+				.assertionConsumerServiceUrlTemplate("{baseUrl}/saml2/authenticate/{registrationId}")
+				.build();
+		Saml2AuthenticationRequestContext context =
+				this.authenticationRequestContextResolver.resolve(this.request, relyingParty);
 
-		assertThat(authenticationRequestContext).isNotNull();
-		assertThat(authenticationRequestContext.getAssertionConsumerServiceUrl()).isEqualTo(TEMPLATE);
-		assertThat(authenticationRequestContext.getRelyingPartyRegistration().getRegistrationId()).isEqualTo(REGISTRATION_ID);
-		assertThat(authenticationRequestContext.getRelyingPartyRegistration().getProviderDetails().getEntityId()).isEqualTo(IDP_ENTITY_ID);
-		assertThat(authenticationRequestContext.getRelyingPartyRegistration().getProviderDetails().getWebSsoUrl()).isEqualTo(IDP_SSO_URL);
-		assertThat(authenticationRequestContext.getRelyingPartyRegistration().getCredentials()).isNotEmpty();
+		assertThat(context.getAssertionConsumerServiceUrl())
+				.isEqualTo("http://localhost/saml2/authenticate/registration-id");
 	}
 
-	@Test(expected = IllegalArgumentException.class)
-	public void resolveWhenRequestAndRelyingPartyNullThenException() {
-		authenticationRequestContextResolver.resolve(null, null);
+	@Test
+	public void resolveWhenRequestNullThenException() {
+		assertThatCode(() ->
+				this.authenticationRequestContextResolver.resolve(this.request, null))
+						.isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void resolveWhenRelyingPartyNullThenException() {
+		assertThatCode(() ->
+				this.authenticationRequestContextResolver.resolve(null, this.relyingPartyBuilder.build()))
+				.isInstanceOf(IllegalArgumentException.class);
 	}
 }