|
@@ -19,6 +19,7 @@ package org.springframework.security.saml2.provider.service.authentication;
|
|
import java.io.ByteArrayOutputStream;
|
|
import java.io.ByteArrayOutputStream;
|
|
import java.io.IOException;
|
|
import java.io.IOException;
|
|
import java.io.ObjectOutputStream;
|
|
import java.io.ObjectOutputStream;
|
|
|
|
+import java.nio.charset.StandardCharsets;
|
|
import java.time.Duration;
|
|
import java.time.Duration;
|
|
import java.time.Instant;
|
|
import java.time.Instant;
|
|
import java.util.Arrays;
|
|
import java.util.Arrays;
|
|
@@ -46,6 +47,7 @@ import org.opensaml.saml.saml2.core.Assertion;
|
|
import org.opensaml.saml.saml2.core.Attribute;
|
|
import org.opensaml.saml.saml2.core.Attribute;
|
|
import org.opensaml.saml.saml2.core.AttributeStatement;
|
|
import org.opensaml.saml.saml2.core.AttributeStatement;
|
|
import org.opensaml.saml.saml2.core.AttributeValue;
|
|
import org.opensaml.saml.saml2.core.AttributeValue;
|
|
|
|
+import org.opensaml.saml.saml2.core.AuthnRequest;
|
|
import org.opensaml.saml.saml2.core.Conditions;
|
|
import org.opensaml.saml.saml2.core.Conditions;
|
|
import org.opensaml.saml.saml2.core.EncryptedAssertion;
|
|
import org.opensaml.saml.saml2.core.EncryptedAssertion;
|
|
import org.opensaml.saml.saml2.core.EncryptedAttribute;
|
|
import org.opensaml.saml.saml2.core.EncryptedAttribute;
|
|
@@ -74,6 +76,7 @@ import org.springframework.security.saml2.core.TestSaml2X509Credentials;
|
|
import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider.ResponseToken;
|
|
import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider.ResponseToken;
|
|
import org.springframework.security.saml2.provider.service.authentication.TestCustomOpenSamlObjects.CustomOpenSamlObject;
|
|
import org.springframework.security.saml2.provider.service.authentication.TestCustomOpenSamlObjects.CustomOpenSamlObject;
|
|
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
|
|
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
|
|
|
|
+import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
|
|
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
|
|
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
|
|
import org.springframework.util.StringUtils;
|
|
import org.springframework.util.StringUtils;
|
|
|
|
|
|
@@ -217,6 +220,111 @@ public class OpenSaml4AuthenticationProviderTests {
|
|
this.provider.authenticate(token);
|
|
this.provider.authenticate(token);
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ @Test
|
|
|
|
+ public void evaluateInResponseToSucceedsWhenInResponseToInResponseAndAssertionsMatchRequestID() {
|
|
|
|
+ Response response = response();
|
|
|
|
+ response.setInResponseTo("SAML2");
|
|
|
|
+ response.getAssertions().add(signed(assertion("SAML2")));
|
|
|
|
+ response.getAssertions().add(signed(assertion("SAML2")));
|
|
|
|
+ AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2",
|
|
|
|
+ Saml2MessageBinding.POST, false);
|
|
|
|
+ Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest);
|
|
|
|
+ this.provider.authenticate(token);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @Test
|
|
|
|
+ public void evaluateInResponseToSucceedsWhenInResponseToInAssertionOnlyMatchRequestID() {
|
|
|
|
+ Response response = response();
|
|
|
|
+ response.getAssertions().add(signed(assertion()));
|
|
|
|
+ response.getAssertions().add(signed(assertion("SAML2")));
|
|
|
|
+ AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2",
|
|
|
|
+ Saml2MessageBinding.POST, false);
|
|
|
|
+ Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest);
|
|
|
|
+ this.provider.authenticate(token);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @Test
|
|
|
|
+ public void evaluateInResponseToFailsWhenInResponseToInAssertionOnlyAndCorruptedStoredRequest() {
|
|
|
|
+ Response response = response();
|
|
|
|
+ response.getAssertions().add(signed(assertion()));
|
|
|
|
+ response.getAssertions().add(signed(assertion("SAML2")));
|
|
|
|
+ AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2",
|
|
|
|
+ Saml2MessageBinding.POST, true);
|
|
|
|
+ Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest);
|
|
|
|
+ assertThatExceptionOfType(Saml2AuthenticationException.class)
|
|
|
|
+ .isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("invalid_assertion");
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @Test
|
|
|
|
+ public void evaluateInResponseToFailsWhenInResponseToInAssertionMismatchWithRequestID() {
|
|
|
|
+ Response response = response();
|
|
|
|
+ response.setInResponseTo("SAML2");
|
|
|
|
+ response.getAssertions().add(signed(assertion("SAML2")));
|
|
|
|
+ response.getAssertions().add(signed(assertion("BAD")));
|
|
|
|
+ AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2",
|
|
|
|
+ Saml2MessageBinding.POST, false);
|
|
|
|
+ Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest);
|
|
|
|
+ assertThatExceptionOfType(Saml2AuthenticationException.class)
|
|
|
|
+ .isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("invalid_assertion");
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @Test
|
|
|
|
+ public void evaluateInResponseToFailsWhenInResponseToInAssertionOnlyAndMismatchWithRequestID() {
|
|
|
|
+ Response response = response();
|
|
|
|
+ response.getAssertions().add(signed(assertion()));
|
|
|
|
+ response.getAssertions().add(signed(assertion("BAD")));
|
|
|
|
+ AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2",
|
|
|
|
+ Saml2MessageBinding.POST, false);
|
|
|
|
+ Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest);
|
|
|
|
+ assertThatExceptionOfType(Saml2AuthenticationException.class)
|
|
|
|
+ .isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("invalid_assertion");
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @Test
|
|
|
|
+ public void evaluateInResponseToFailsWhenInResponseInToResponseMismatchWithRequestID() {
|
|
|
|
+ Response response = response();
|
|
|
|
+ response.setInResponseTo("BAD");
|
|
|
|
+ response.getAssertions().add(signed(assertion("SAML2")));
|
|
|
|
+ response.getAssertions().add(signed(assertion("SAML2")));
|
|
|
|
+ AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2",
|
|
|
|
+ Saml2MessageBinding.POST, false);
|
|
|
|
+ Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest);
|
|
|
|
+ assertThatExceptionOfType(Saml2AuthenticationException.class)
|
|
|
|
+ .isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("invalid_in_response_to");
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @Test
|
|
|
|
+ public void evaluateInResponseToFailsWhenInResponseInToResponseAndCorruptedStoredRequest() {
|
|
|
|
+ Response response = response();
|
|
|
|
+ response.setInResponseTo("SAML2");
|
|
|
|
+ response.getAssertions().add(signed(assertion()));
|
|
|
|
+ response.getAssertions().add(signed(assertion()));
|
|
|
|
+ AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2",
|
|
|
|
+ Saml2MessageBinding.POST, true);
|
|
|
|
+ Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest);
|
|
|
|
+ assertThatExceptionOfType(Saml2AuthenticationException.class)
|
|
|
|
+ .isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("malformed_request_data");
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @Test
|
|
|
|
+ public void evaluateInResponseToFailsWhenInResponseToInResponseButNoSavedRequest() {
|
|
|
|
+ Response response = response();
|
|
|
|
+ response.setInResponseTo("BAD");
|
|
|
|
+ Saml2AuthenticationToken token = token(response, verifying(registration()));
|
|
|
|
+ assertThatExceptionOfType(Saml2AuthenticationException.class)
|
|
|
|
+ .isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("invalid_in_response_to");
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @Test
|
|
|
|
+ public void evaluateInResponseToSucceedsWhenNoInResponseToInResponseOrAssertions() {
|
|
|
|
+ Response response = response();
|
|
|
|
+ response.getAssertions().add(signed(assertion()));
|
|
|
|
+ AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2",
|
|
|
|
+ Saml2MessageBinding.POST, false);
|
|
|
|
+ Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest);
|
|
|
|
+ this.provider.authenticate(token);
|
|
|
|
+ }
|
|
|
|
+
|
|
@Test
|
|
@Test
|
|
public void authenticateWhenAssertionContainsAttributesThenItSucceeds() {
|
|
public void authenticateWhenAssertionContainsAttributesThenItSucceeds() {
|
|
Response response = response();
|
|
Response response = response();
|
|
@@ -658,13 +766,27 @@ public class OpenSaml4AuthenticationProviderTests {
|
|
return response;
|
|
return response;
|
|
}
|
|
}
|
|
|
|
|
|
- private Assertion assertion() {
|
|
|
|
|
|
+ private AuthnRequest request() {
|
|
|
|
+ AuthnRequest request = TestOpenSamlObjects.authnRequest();
|
|
|
|
+ return request;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ private String serializedRequest(AuthnRequest request, Saml2MessageBinding binding) {
|
|
|
|
+ String xml = serialize(request);
|
|
|
|
+ return (binding == Saml2MessageBinding.POST) ? Saml2Utils.samlEncode(xml.getBytes(StandardCharsets.UTF_8))
|
|
|
|
+ : Saml2Utils.samlEncode(Saml2Utils.samlDeflate(xml));
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ private Assertion assertion(String inResponseTo) {
|
|
Assertion assertion = TestOpenSamlObjects.assertion();
|
|
Assertion assertion = TestOpenSamlObjects.assertion();
|
|
assertion.setIssueInstant(Instant.now());
|
|
assertion.setIssueInstant(Instant.now());
|
|
for (SubjectConfirmation confirmation : assertion.getSubject().getSubjectConfirmations()) {
|
|
for (SubjectConfirmation confirmation : assertion.getSubject().getSubjectConfirmations()) {
|
|
SubjectConfirmationData data = confirmation.getSubjectConfirmationData();
|
|
SubjectConfirmationData data = confirmation.getSubjectConfirmationData();
|
|
data.setNotBefore(Instant.now().minus(Duration.ofMillis(5 * 60 * 1000)));
|
|
data.setNotBefore(Instant.now().minus(Duration.ofMillis(5 * 60 * 1000)));
|
|
data.setNotOnOrAfter(Instant.now().plus(Duration.ofMillis(5 * 60 * 1000)));
|
|
data.setNotOnOrAfter(Instant.now().plus(Duration.ofMillis(5 * 60 * 1000)));
|
|
|
|
+ if (StringUtils.hasText(inResponseTo)) {
|
|
|
|
+ data.setInResponseTo(inResponseTo);
|
|
|
|
+ }
|
|
}
|
|
}
|
|
Conditions conditions = assertion.getConditions();
|
|
Conditions conditions = assertion.getConditions();
|
|
conditions.setNotBefore(Instant.now().minus(Duration.ofMillis(5 * 60 * 1000)));
|
|
conditions.setNotBefore(Instant.now().minus(Duration.ofMillis(5 * 60 * 1000)));
|
|
@@ -672,6 +794,10 @@ public class OpenSaml4AuthenticationProviderTests {
|
|
return assertion;
|
|
return assertion;
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ private Assertion assertion() {
|
|
|
|
+ return assertion(null);
|
|
|
|
+ }
|
|
|
|
+
|
|
private <T extends SignableSAMLObject> T signed(T toSign) {
|
|
private <T extends SignableSAMLObject> T signed(T toSign) {
|
|
TestOpenSamlObjects.signed(toSign, TestSaml2X509Credentials.assertingPartySigningCredential(),
|
|
TestOpenSamlObjects.signed(toSign, TestSaml2X509Credentials.assertingPartySigningCredential(),
|
|
RELYING_PARTY_ENTITY_ID);
|
|
RELYING_PARTY_ENTITY_ID);
|
|
@@ -701,6 +827,27 @@ public class OpenSaml4AuthenticationProviderTests {
|
|
return new Saml2AuthenticationToken(registration.build(), serialize(response));
|
|
return new Saml2AuthenticationToken(registration.build(), serialize(response));
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ private Saml2AuthenticationToken token(Response response, RelyingPartyRegistration.Builder registration,
|
|
|
|
+ AbstractSaml2AuthenticationRequest authenticationRequest) {
|
|
|
|
+ return new Saml2AuthenticationToken(registration.build(), serialize(response), authenticationRequest);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ private AbstractSaml2AuthenticationRequest mockedStoredAuthenticationRequest(String requestId,
|
|
|
|
+ Saml2MessageBinding binding, boolean corruptRequestString) {
|
|
|
|
+ AuthnRequest request = request();
|
|
|
|
+ if (requestId != null) {
|
|
|
|
+ request.setID(requestId);
|
|
|
|
+ }
|
|
|
|
+ String serializedRequest = serializedRequest(request, binding);
|
|
|
|
+ if (corruptRequestString) {
|
|
|
|
+ serializedRequest = serializedRequest.substring(2, serializedRequest.length() - 2);
|
|
|
|
+ }
|
|
|
|
+ AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mock(AbstractSaml2AuthenticationRequest.class);
|
|
|
|
+ given(mockAuthenticationRequest.getSamlRequest()).willReturn(serializedRequest);
|
|
|
|
+ given(mockAuthenticationRequest.getBinding()).willReturn(binding);
|
|
|
|
+ return mockAuthenticationRequest;
|
|
|
|
+ }
|
|
|
|
+
|
|
private RelyingPartyRegistration.Builder registration() {
|
|
private RelyingPartyRegistration.Builder registration() {
|
|
return TestRelyingPartyRegistrations.noCredentials().entityId(RELYING_PARTY_ENTITY_ID)
|
|
return TestRelyingPartyRegistrations.noCredentials().entityId(RELYING_PARTY_ENTITY_ID)
|
|
.assertionConsumerServiceLocation(DESTINATION)
|
|
.assertionConsumerServiceLocation(DESTINATION)
|