Browse Source

Add findUniqueByAssertingPartyEntityId

Closes gh-12848
Josh Cummings 2 years ago
parent
commit
97d1a49daf

+ 30 - 2
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/InMemoryRelyingPartyRegistrationRepository.java

@@ -21,15 +21,19 @@ import java.util.Collection;
 import java.util.Collections;
 import java.util.Iterator;
 import java.util.LinkedHashMap;
+import java.util.List;
 import java.util.Map;
 
 import org.springframework.util.Assert;
+import org.springframework.util.LinkedMultiValueMap;
+import org.springframework.util.MultiValueMap;
 
 /**
- * An in-memory implementation of {@link RelyingPartyRegistrationRepository}.
- * Also implements {@link Iterable} to simplify the default login page.
+ * An in-memory implementation of {@link RelyingPartyRegistrationRepository}. Also
+ * implements {@link Iterable} to simplify the default login page.
  *
  * @author Filip Hanik
+ * @author Josh Cummings
  * @since 5.2
  */
 public class InMemoryRelyingPartyRegistrationRepository
@@ -37,6 +41,8 @@ public class InMemoryRelyingPartyRegistrationRepository
 
 	private final Map<String, RelyingPartyRegistration> byRegistrationId;
 
+	private final Map<String, List<RelyingPartyRegistration>> byAssertingPartyEntityId;
+
 	public InMemoryRelyingPartyRegistrationRepository(RelyingPartyRegistration... registrations) {
 		this(Arrays.asList(registrations));
 	}
@@ -44,6 +50,7 @@ public class InMemoryRelyingPartyRegistrationRepository
 	public InMemoryRelyingPartyRegistrationRepository(Collection<RelyingPartyRegistration> registrations) {
 		Assert.notEmpty(registrations, "registrations cannot be empty");
 		this.byRegistrationId = createMappingToIdentityProvider(registrations);
+		this.byAssertingPartyEntityId = createMappingByAssertingPartyEntityId(registrations);
 	}
 
 	private static Map<String, RelyingPartyRegistration> createMappingToIdentityProvider(
@@ -59,11 +66,32 @@ public class InMemoryRelyingPartyRegistrationRepository
 		return Collections.unmodifiableMap(result);
 	}
 
+	private static Map<String, List<RelyingPartyRegistration>> createMappingByAssertingPartyEntityId(
+			Collection<RelyingPartyRegistration> rps) {
+		MultiValueMap<String, RelyingPartyRegistration> result = new LinkedMultiValueMap<>();
+		for (RelyingPartyRegistration rp : rps) {
+			result.add(rp.getAssertingPartyDetails().getEntityId(), rp);
+		}
+		return Collections.unmodifiableMap(result);
+	}
+
 	@Override
 	public RelyingPartyRegistration findByRegistrationId(String id) {
 		return this.byRegistrationId.get(id);
 	}
 
+	@Override
+	public RelyingPartyRegistration findUniqueByAssertingPartyEntityId(String entityId) {
+		Collection<RelyingPartyRegistration> registrations = this.byAssertingPartyEntityId.get(entityId);
+		if (registrations == null) {
+			return null;
+		}
+		if (registrations.size() > 1) {
+			return null;
+		}
+		return registrations.iterator().next();
+	}
+
 	@Override
 	public Iterator<RelyingPartyRegistration> iterator() {
 		return this.byRegistrationId.values().iterator();

+ 13 - 0
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationRepository.java

@@ -20,6 +20,7 @@ package org.springframework.security.saml2.provider.service.registration;
  * A repository for {@link RelyingPartyRegistration}s
  *
  * @author Filip Hanik
+ * @author Josh Cummings
  * @since 5.2
  */
 public interface RelyingPartyRegistrationRepository {
@@ -32,4 +33,16 @@ public interface RelyingPartyRegistrationRepository {
 	 */
 	RelyingPartyRegistration findByRegistrationId(String registrationId);
 
+	/**
+	 * Returns the unique relying party registration associated with the asserting party's
+	 * {@code entityId} or {@code null} if there is no unique match.
+	 * @param entityId the asserting party's entity id
+	 * @return the unique {@link RelyingPartyRegistration} associated the given asserting
+	 * party; {@code null} of there is no unique match asserting party
+	 * @since 6.1
+	 */
+	default RelyingPartyRegistration findUniqueByAssertingPartyEntityId(String entityId) {
+		return findByRegistrationId(entityId);
+	}
+
 }

+ 194 - 0
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/OpenSamlAuthenticationTokenConverter.java

@@ -0,0 +1,194 @@
+/*
+ * Copyright 2002-2022 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.saml2.provider.service.web;
+
+import java.io.ByteArrayOutputStream;
+import java.nio.charset.StandardCharsets;
+import java.util.Arrays;
+import java.util.Base64;
+import java.util.function.Function;
+import java.util.zip.Inflater;
+import java.util.zip.InflaterOutputStream;
+
+import jakarta.servlet.http.HttpServletRequest;
+
+import org.springframework.http.HttpMethod;
+import org.springframework.security.saml2.core.Saml2Error;
+import org.springframework.security.saml2.core.Saml2ErrorCodes;
+import org.springframework.security.saml2.core.Saml2ParameterNames;
+import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
+import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
+import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
+import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
+import org.springframework.security.web.authentication.AuthenticationConverter;
+import org.springframework.util.Assert;
+
+/**
+ * An {@link AuthenticationConverter} that generates a {@link Saml2AuthenticationToken}
+ * appropriate for authenticated a SAML 2.0 Assertion against an
+ * {@link org.springframework.security.authentication.AuthenticationManager}.
+ *
+ * @author Josh Cummings
+ * @since 5.4
+ */
+public final class Saml2AuthenticationTokenConverter implements AuthenticationConverter {
+
+	// MimeDecoder allows extra line-breaks as well as other non-alphabet values.
+	// This matches the behaviour of the commons-codec decoder.
+	private static final Base64.Decoder BASE64 = Base64.getMimeDecoder();
+
+	private static final Base64Checker BASE_64_CHECKER = new Base64Checker();
+
+	private final RelyingPartyRegistrationResolver relyingPartyRegistrationResolver;
+
+	private Function<HttpServletRequest, AbstractSaml2AuthenticationRequest> loader;
+
+	/**
+	 * Constructs a {@link Saml2AuthenticationTokenConverter} given a strategy for
+	 * resolving {@link RelyingPartyRegistration}s
+	 * @param relyingPartyRegistrationResolver the strategy for resolving
+	 * {@link RelyingPartyRegistration}s
+	 */
+	public Saml2AuthenticationTokenConverter(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) {
+		Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null");
+		this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver;
+		this.loader = new HttpSessionSaml2AuthenticationRequestRepository()::loadAuthenticationRequest;
+	}
+
+	@Override
+	public Saml2AuthenticationToken convert(HttpServletRequest request) {
+		AbstractSaml2AuthenticationRequest authenticationRequest = loadAuthenticationRequest(request);
+		String relyingPartyRegistrationId = (authenticationRequest != null)
+				? authenticationRequest.getRelyingPartyRegistrationId() : null;
+		RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationResolver.resolve(request,
+				relyingPartyRegistrationId);
+		if (relyingPartyRegistration == null) {
+			return null;
+		}
+		String saml2Response = request.getParameter(Saml2ParameterNames.SAML_RESPONSE);
+		if (saml2Response == null) {
+			return null;
+		}
+		byte[] b = samlDecode(saml2Response);
+		saml2Response = inflateIfRequired(request, b);
+		return new Saml2AuthenticationToken(relyingPartyRegistration, saml2Response, authenticationRequest);
+	}
+
+	/**
+	 * Use the given {@link Saml2AuthenticationRequestRepository} to load authentication
+	 * request.
+	 * @param authenticationRequestRepository the
+	 * {@link Saml2AuthenticationRequestRepository} to use
+	 * @since 5.6
+	 */
+	public void setAuthenticationRequestRepository(
+			Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository) {
+		Assert.notNull(authenticationRequestRepository, "authenticationRequestRepository cannot be null");
+		this.loader = authenticationRequestRepository::loadAuthenticationRequest;
+	}
+
+	private AbstractSaml2AuthenticationRequest loadAuthenticationRequest(HttpServletRequest request) {
+		return this.loader.apply(request);
+	}
+
+	private String inflateIfRequired(HttpServletRequest request, byte[] b) {
+		if (HttpMethod.GET.matches(request.getMethod())) {
+			return samlInflate(b);
+		}
+		return new String(b, StandardCharsets.UTF_8);
+	}
+
+	private byte[] samlDecode(String base64EncodedPayload) {
+		try {
+			BASE_64_CHECKER.checkAcceptable(base64EncodedPayload);
+			return BASE64.decode(base64EncodedPayload);
+		}
+		catch (Exception ex) {
+			throw new Saml2AuthenticationException(
+					new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, "Failed to decode SAMLResponse"), ex);
+		}
+	}
+
+	private String samlInflate(byte[] b) {
+		try {
+			ByteArrayOutputStream out = new ByteArrayOutputStream();
+			InflaterOutputStream inflaterOutputStream = new InflaterOutputStream(out, new Inflater(true));
+			inflaterOutputStream.write(b);
+			inflaterOutputStream.finish();
+			return out.toString(StandardCharsets.UTF_8.name());
+		}
+		catch (Exception ex) {
+			throw new Saml2AuthenticationException(
+					new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, "Unable to inflate string"), ex);
+		}
+	}
+
+	static class Base64Checker {
+
+		private static final int[] values = genValueMapping();
+
+		Base64Checker() {
+
+		}
+
+		private static int[] genValueMapping() {
+			byte[] alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
+					.getBytes(StandardCharsets.ISO_8859_1);
+
+			int[] values = new int[256];
+			Arrays.fill(values, -1);
+			for (int i = 0; i < alphabet.length; i++) {
+				values[alphabet[i] & 0xff] = i;
+			}
+			return values;
+		}
+
+		boolean isAcceptable(String s) {
+			int goodChars = 0;
+			int lastGoodCharVal = -1;
+
+			// count number of characters from Base64 alphabet
+			for (int i = 0; i < s.length(); i++) {
+				int val = values[0xff & s.charAt(i)];
+				if (val != -1) {
+					lastGoodCharVal = val;
+					goodChars++;
+				}
+			}
+
+			// in cases of an incomplete final chunk, ensure the unused bits are zero
+			switch (goodChars % 4) {
+			case 0:
+				return true;
+			case 2:
+				return (lastGoodCharVal & 0b1111) == 0;
+			case 3:
+				return (lastGoodCharVal & 0b11) == 0;
+			default:
+				return false;
+			}
+		}
+
+		void checkAcceptable(String ins) {
+			if (!isAcceptable(ins)) {
+				throw new IllegalArgumentException("Unaccepted Encoding");
+			}
+		}
+
+	}
+
+}

+ 18 - 0
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/InMemoryRelyingPartyRegistrationRepositoryTests.java

@@ -42,4 +42,22 @@ public class InMemoryRelyingPartyRegistrationRepositoryTests {
 		assertThat(registrations.findByRegistrationId(null)).isNull();
 	}
 
+	@Test
+	void findByAssertingPartyEntityIdWhenGivenEntityIdThenReturnsMatchingRegistrations() {
+		RelyingPartyRegistration registration = TestRelyingPartyRegistrations.relyingPartyRegistration().build();
+		InMemoryRelyingPartyRegistrationRepository registrations = new InMemoryRelyingPartyRegistrationRepository(
+				registration);
+		String assertingPartyEntityId = registration.getAssertingPartyDetails().getEntityId();
+		assertThat(registrations.findUniqueByAssertingPartyEntityId(assertingPartyEntityId)).isEqualTo(registration);
+	}
+
+	@Test
+	void findByAssertingPartyEntityIdWhenGivenWrongEntityIdThenReturnsEmpty() {
+		RelyingPartyRegistration registration = TestRelyingPartyRegistrations.relyingPartyRegistration().build();
+		InMemoryRelyingPartyRegistrationRepository registrations = new InMemoryRelyingPartyRegistrationRepository(
+				registration);
+		String assertingPartyEntityId = registration.getAssertingPartyDetails().getEntityId();
+		assertThat(registrations.findUniqueByAssertingPartyEntityId(assertingPartyEntityId + "wrong")).isNull();
+	}
+
 }