|
@@ -1,5 +1,5 @@
|
|
/*
|
|
/*
|
|
- * Copyright 2002-2022 the original author or authors.
|
|
|
|
|
|
+ * Copyright 2002-2023 the original author or authors.
|
|
*
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* you may not use this file except in compliance with the License.
|
|
@@ -16,6 +16,7 @@
|
|
|
|
|
|
package org.springframework.security.saml2.provider.service.web;
|
|
package org.springframework.security.saml2.provider.service.web;
|
|
|
|
|
|
|
|
+import java.io.ByteArrayInputStream;
|
|
import java.io.ByteArrayOutputStream;
|
|
import java.io.ByteArrayOutputStream;
|
|
import java.nio.charset.StandardCharsets;
|
|
import java.nio.charset.StandardCharsets;
|
|
import java.util.Arrays;
|
|
import java.util.Arrays;
|
|
@@ -25,8 +26,18 @@ import java.util.zip.Inflater;
|
|
import java.util.zip.InflaterOutputStream;
|
|
import java.util.zip.InflaterOutputStream;
|
|
|
|
|
|
import jakarta.servlet.http.HttpServletRequest;
|
|
import jakarta.servlet.http.HttpServletRequest;
|
|
|
|
+import net.shibboleth.utilities.java.support.xml.ParserPool;
|
|
|
|
+import org.opensaml.core.config.ConfigurationService;
|
|
|
|
+import org.opensaml.core.xml.config.XMLObjectProviderRegistry;
|
|
|
|
+import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
|
|
|
|
+import org.opensaml.saml.saml2.core.Response;
|
|
|
|
+import org.opensaml.saml.saml2.core.impl.ResponseUnmarshaller;
|
|
|
|
+import org.w3c.dom.Document;
|
|
|
|
+import org.w3c.dom.Element;
|
|
|
|
|
|
import org.springframework.http.HttpMethod;
|
|
import org.springframework.http.HttpMethod;
|
|
|
|
+import org.springframework.security.saml2.Saml2Exception;
|
|
|
|
+import org.springframework.security.saml2.core.OpenSamlInitializationService;
|
|
import org.springframework.security.saml2.core.Saml2Error;
|
|
import org.springframework.security.saml2.core.Saml2Error;
|
|
import org.springframework.security.saml2.core.Saml2ErrorCodes;
|
|
import org.springframework.security.saml2.core.Saml2ErrorCodes;
|
|
import org.springframework.security.saml2.core.Saml2ParameterNames;
|
|
import org.springframework.security.saml2.core.Saml2ParameterNames;
|
|
@@ -34,7 +45,12 @@ import org.springframework.security.saml2.provider.service.authentication.Abstra
|
|
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
|
|
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
|
|
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
|
|
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
|
|
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
|
|
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
|
|
|
|
+import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
|
|
|
|
+import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationPlaceholderResolvers.UriResolver;
|
|
import org.springframework.security.web.authentication.AuthenticationConverter;
|
|
import org.springframework.security.web.authentication.AuthenticationConverter;
|
|
|
|
+import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
|
|
|
|
+import org.springframework.security.web.util.matcher.OrRequestMatcher;
|
|
|
|
+import org.springframework.security.web.util.matcher.RequestMatcher;
|
|
import org.springframework.util.Assert;
|
|
import org.springframework.util.Assert;
|
|
|
|
|
|
/**
|
|
/**
|
|
@@ -43,9 +59,13 @@ import org.springframework.util.Assert;
|
|
* {@link org.springframework.security.authentication.AuthenticationManager}.
|
|
* {@link org.springframework.security.authentication.AuthenticationManager}.
|
|
*
|
|
*
|
|
* @author Josh Cummings
|
|
* @author Josh Cummings
|
|
- * @since 5.4
|
|
|
|
|
|
+ * @since 6.1
|
|
*/
|
|
*/
|
|
-public final class Saml2AuthenticationTokenConverter implements AuthenticationConverter {
|
|
|
|
|
|
+public final class OpenSamlAuthenticationTokenConverter implements AuthenticationConverter {
|
|
|
|
+
|
|
|
|
+ static {
|
|
|
|
+ OpenSamlInitializationService.initialize();
|
|
|
|
+ }
|
|
|
|
|
|
// MimeDecoder allows extra line-breaks as well as other non-alphabet values.
|
|
// MimeDecoder allows extra line-breaks as well as other non-alphabet values.
|
|
// This matches the behaviour of the commons-codec decoder.
|
|
// This matches the behaviour of the commons-codec decoder.
|
|
@@ -53,39 +73,120 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo
|
|
|
|
|
|
private static final Base64Checker BASE_64_CHECKER = new Base64Checker();
|
|
private static final Base64Checker BASE_64_CHECKER = new Base64Checker();
|
|
|
|
|
|
- private final RelyingPartyRegistrationResolver relyingPartyRegistrationResolver;
|
|
|
|
|
|
+ private final RelyingPartyRegistrationRepository registrations;
|
|
|
|
+
|
|
|
|
+ private RequestMatcher requestMatcher = new OrRequestMatcher(
|
|
|
|
+ new AntPathRequestMatcher("/login/saml2/sso/{registrationId}"),
|
|
|
|
+ new AntPathRequestMatcher("/login/saml2/sso"));
|
|
|
|
+
|
|
|
|
+ private final ParserPool parserPool;
|
|
|
|
+
|
|
|
|
+ private final ResponseUnmarshaller unmarshaller;
|
|
|
|
|
|
private Function<HttpServletRequest, AbstractSaml2AuthenticationRequest> loader;
|
|
private Function<HttpServletRequest, AbstractSaml2AuthenticationRequest> loader;
|
|
|
|
|
|
/**
|
|
/**
|
|
- * Constructs a {@link Saml2AuthenticationTokenConverter} given a strategy for
|
|
|
|
- * resolving {@link RelyingPartyRegistration}s
|
|
|
|
- * @param relyingPartyRegistrationResolver the strategy for resolving
|
|
|
|
|
|
+ * Constructs a {@link OpenSamlAuthenticationTokenConverter} given a repository for
|
|
|
|
+ * {@link RelyingPartyRegistration}s
|
|
|
|
+ * @param registrations the repository for {@link RelyingPartyRegistration}s
|
|
* {@link RelyingPartyRegistration}s
|
|
* {@link RelyingPartyRegistration}s
|
|
*/
|
|
*/
|
|
- public Saml2AuthenticationTokenConverter(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) {
|
|
|
|
- Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null");
|
|
|
|
- this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver;
|
|
|
|
|
|
+ public OpenSamlAuthenticationTokenConverter(RelyingPartyRegistrationRepository registrations) {
|
|
|
|
+ Assert.notNull(registrations, "relyingPartyRegistrationRepository cannot be null");
|
|
|
|
+ XMLObjectProviderRegistry registry = ConfigurationService.get(XMLObjectProviderRegistry.class);
|
|
|
|
+ this.parserPool = registry.getParserPool();
|
|
|
|
+ this.unmarshaller = (ResponseUnmarshaller) XMLObjectProviderRegistrySupport.getUnmarshallerFactory()
|
|
|
|
+ .getUnmarshaller(Response.DEFAULT_ELEMENT_NAME);
|
|
|
|
+ this.registrations = registrations;
|
|
this.loader = new HttpSessionSaml2AuthenticationRequestRepository()::loadAuthenticationRequest;
|
|
this.loader = new HttpSessionSaml2AuthenticationRequestRepository()::loadAuthenticationRequest;
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ /**
|
|
|
|
+ * Resolve an authentication request from the given {@link HttpServletRequest}.
|
|
|
|
+ *
|
|
|
|
+ * <p>
|
|
|
|
+ * First uses the configured {@link RequestMatcher} to deduce whether an
|
|
|
|
+ * authentication request is being made and optionally for which
|
|
|
|
+ * {@code registrationId}.
|
|
|
|
+ *
|
|
|
|
+ * <p>
|
|
|
|
+ * If there is an associated {@code <saml2:AuthnRequest>}, then the
|
|
|
|
+ * {@code registrationId} is looked up and used.
|
|
|
|
+ *
|
|
|
|
+ * <p>
|
|
|
|
+ * If a {@code registrationId} is found in the request, then it is looked up and used.
|
|
|
|
+ * In that case, if none is found a {@link Saml2AuthenticationException} is thrown.
|
|
|
|
+ *
|
|
|
|
+ * <p>
|
|
|
|
+ * Finally, if no {@code registrationId} is found in the request, then the code
|
|
|
|
+ * attempts to resolve the {@link RelyingPartyRegistration} from the SAML Response's
|
|
|
|
+ * Issuer.
|
|
|
|
+ * @param request the HTTP request
|
|
|
|
+ * @return the {@link Saml2AuthenticationToken} authentication request
|
|
|
|
+ * @throws Saml2AuthenticationException if the {@link RequestMatcher} specifies a
|
|
|
|
+ * non-existent {@code registrationId}
|
|
|
|
+ */
|
|
@Override
|
|
@Override
|
|
public Saml2AuthenticationToken convert(HttpServletRequest request) {
|
|
public Saml2AuthenticationToken convert(HttpServletRequest request) {
|
|
|
|
+ String serialized = request.getParameter(Saml2ParameterNames.SAML_RESPONSE);
|
|
|
|
+ if (serialized == null) {
|
|
|
|
+ return null;
|
|
|
|
+ }
|
|
|
|
+ RequestMatcher.MatchResult result = this.requestMatcher.matcher(request);
|
|
|
|
+ if (!result.isMatch()) {
|
|
|
|
+ return null;
|
|
|
|
+ }
|
|
|
|
+ Saml2AuthenticationToken token = tokenByAuthenticationRequest(request);
|
|
|
|
+ if (token == null) {
|
|
|
|
+ token = tokenByRegistrationId(request, result);
|
|
|
|
+ }
|
|
|
|
+ if (token == null) {
|
|
|
|
+ token = tokenByEntityId(request);
|
|
|
|
+ }
|
|
|
|
+ return token;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ private Saml2AuthenticationToken tokenByAuthenticationRequest(HttpServletRequest request) {
|
|
AbstractSaml2AuthenticationRequest authenticationRequest = loadAuthenticationRequest(request);
|
|
AbstractSaml2AuthenticationRequest authenticationRequest = loadAuthenticationRequest(request);
|
|
- String relyingPartyRegistrationId = (authenticationRequest != null)
|
|
|
|
- ? authenticationRequest.getRelyingPartyRegistrationId() : null;
|
|
|
|
- RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationResolver.resolve(request,
|
|
|
|
- relyingPartyRegistrationId);
|
|
|
|
- if (relyingPartyRegistration == null) {
|
|
|
|
|
|
+ if (authenticationRequest == null) {
|
|
|
|
+ return null;
|
|
|
|
+ }
|
|
|
|
+ String registrationId = authenticationRequest.getRelyingPartyRegistrationId();
|
|
|
|
+ RelyingPartyRegistration registration = this.registrations.findByRegistrationId(registrationId);
|
|
|
|
+ return tokenByRegistration(request, registration, authenticationRequest);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ private Saml2AuthenticationToken tokenByRegistrationId(HttpServletRequest request,
|
|
|
|
+ RequestMatcher.MatchResult result) {
|
|
|
|
+ String registrationId = result.getVariables().get("registrationId");
|
|
|
|
+ if (registrationId == null) {
|
|
return null;
|
|
return null;
|
|
}
|
|
}
|
|
- String saml2Response = request.getParameter(Saml2ParameterNames.SAML_RESPONSE);
|
|
|
|
- if (saml2Response == null) {
|
|
|
|
|
|
+ RelyingPartyRegistration registration = this.registrations.findByRegistrationId(registrationId);
|
|
|
|
+ return tokenByRegistration(request, registration, null);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ private Saml2AuthenticationToken tokenByEntityId(HttpServletRequest request) {
|
|
|
|
+ String serialized = request.getParameter(Saml2ParameterNames.SAML_RESPONSE);
|
|
|
|
+ String decoded = new String(samlDecode(serialized), StandardCharsets.UTF_8);
|
|
|
|
+ Response response = parse(decoded);
|
|
|
|
+ String issuer = response.getIssuer().getValue();
|
|
|
|
+ RelyingPartyRegistration registration = this.registrations.findUniqueByAssertingPartyEntityId(issuer);
|
|
|
|
+ return tokenByRegistration(request, registration, null);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ private Saml2AuthenticationToken tokenByRegistration(HttpServletRequest request,
|
|
|
|
+ RelyingPartyRegistration registration, AbstractSaml2AuthenticationRequest authenticationRequest) {
|
|
|
|
+ if (registration == null) {
|
|
return null;
|
|
return null;
|
|
}
|
|
}
|
|
- byte[] b = samlDecode(saml2Response);
|
|
|
|
- saml2Response = inflateIfRequired(request, b);
|
|
|
|
- return new Saml2AuthenticationToken(relyingPartyRegistration, saml2Response, authenticationRequest);
|
|
|
|
|
|
+ String serialized = request.getParameter(Saml2ParameterNames.SAML_RESPONSE);
|
|
|
|
+ String decoded = inflateIfRequired(request, samlDecode(serialized));
|
|
|
|
+ UriResolver resolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request, registration);
|
|
|
|
+ registration = registration.mutate().entityId(resolver.resolve(registration.getEntityId()))
|
|
|
|
+ .assertionConsumerServiceLocation(resolver.resolve(registration.getAssertionConsumerServiceLocation()))
|
|
|
|
+ .build();
|
|
|
|
+ return new Saml2AuthenticationToken(registration, decoded, authenticationRequest);
|
|
}
|
|
}
|
|
|
|
|
|
/**
|
|
/**
|
|
@@ -100,6 +201,15 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo
|
|
this.loader = authenticationRequestRepository::loadAuthenticationRequest;
|
|
this.loader = authenticationRequestRepository::loadAuthenticationRequest;
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ /**
|
|
|
|
+ * Use the given {@link RequestMatcher} to match the request.
|
|
|
|
+ * @param requestMatcher the {@link RequestMatcher} to use
|
|
|
|
+ */
|
|
|
|
+ public void setRequestMatcher(RequestMatcher requestMatcher) {
|
|
|
|
+ Assert.notNull(requestMatcher, "requestMatcher cannot be null");
|
|
|
|
+ this.requestMatcher = requestMatcher;
|
|
|
|
+ }
|
|
|
|
+
|
|
private AbstractSaml2AuthenticationRequest loadAuthenticationRequest(HttpServletRequest request) {
|
|
private AbstractSaml2AuthenticationRequest loadAuthenticationRequest(HttpServletRequest request) {
|
|
return this.loader.apply(request);
|
|
return this.loader.apply(request);
|
|
}
|
|
}
|
|
@@ -136,6 +246,18 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ private Response parse(String request) throws Saml2Exception {
|
|
|
|
+ try {
|
|
|
|
+ Document document = this.parserPool
|
|
|
|
+ .parse(new ByteArrayInputStream(request.getBytes(StandardCharsets.UTF_8)));
|
|
|
|
+ Element element = document.getDocumentElement();
|
|
|
|
+ return (Response) this.unmarshaller.unmarshall(element);
|
|
|
|
+ }
|
|
|
|
+ catch (Exception ex) {
|
|
|
|
+ throw new Saml2Exception("Failed to deserialize LogoutRequest", ex);
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
static class Base64Checker {
|
|
static class Base64Checker {
|
|
|
|
|
|
private static final int[] values = genValueMapping();
|
|
private static final int[] values = genValueMapping();
|