|
@@ -16,10 +16,6 @@
|
|
|
|
|
|
package org.springframework.security.saml2.provider.service.web;
|
|
package org.springframework.security.saml2.provider.service.web;
|
|
|
|
|
|
-import java.util.HashMap;
|
|
|
|
-import java.util.Map;
|
|
|
|
-import java.util.function.Function;
|
|
|
|
-
|
|
|
|
import jakarta.servlet.http.HttpServletRequest;
|
|
import jakarta.servlet.http.HttpServletRequest;
|
|
import org.apache.commons.logging.Log;
|
|
import org.apache.commons.logging.Log;
|
|
import org.apache.commons.logging.LogFactory;
|
|
import org.apache.commons.logging.LogFactory;
|
|
@@ -27,13 +23,10 @@ import org.apache.commons.logging.LogFactory;
|
|
import org.springframework.core.convert.converter.Converter;
|
|
import org.springframework.core.convert.converter.Converter;
|
|
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
|
|
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
|
|
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
|
|
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
|
|
-import org.springframework.security.web.util.UrlUtils;
|
|
|
|
|
|
+import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationPlaceholderResolvers.UriResolver;
|
|
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
|
|
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
|
|
import org.springframework.security.web.util.matcher.RequestMatcher;
|
|
import org.springframework.security.web.util.matcher.RequestMatcher;
|
|
import org.springframework.util.Assert;
|
|
import org.springframework.util.Assert;
|
|
-import org.springframework.util.StringUtils;
|
|
|
|
-import org.springframework.web.util.UriComponents;
|
|
|
|
-import org.springframework.web.util.UriComponentsBuilder;
|
|
|
|
|
|
|
|
/**
|
|
/**
|
|
* A {@link Converter} that resolves a {@link RelyingPartyRegistration} by extracting the
|
|
* A {@link Converter} that resolves a {@link RelyingPartyRegistration} by extracting the
|
|
@@ -48,8 +41,6 @@ public final class DefaultRelyingPartyRegistrationResolver
|
|
|
|
|
|
private Log logger = LogFactory.getLog(getClass());
|
|
private Log logger = LogFactory.getLog(getClass());
|
|
|
|
|
|
- private static final char PATH_DELIMITER = '/';
|
|
|
|
-
|
|
|
|
private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;
|
|
private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;
|
|
|
|
|
|
private final RequestMatcher registrationRequestMatcher = new AntPathRequestMatcher("/**/{registrationId}");
|
|
private final RequestMatcher registrationRequestMatcher = new AntPathRequestMatcher("/**/{registrationId}");
|
|
@@ -87,61 +78,19 @@ public final class DefaultRelyingPartyRegistrationResolver
|
|
}
|
|
}
|
|
return null;
|
|
return null;
|
|
}
|
|
}
|
|
- RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationRepository
|
|
|
|
|
|
+ RelyingPartyRegistration registration = this.relyingPartyRegistrationRepository
|
|
.findByRegistrationId(relyingPartyRegistrationId);
|
|
.findByRegistrationId(relyingPartyRegistrationId);
|
|
- if (relyingPartyRegistration == null) {
|
|
|
|
|
|
+ if (registration == null) {
|
|
return null;
|
|
return null;
|
|
}
|
|
}
|
|
- String applicationUri = getApplicationUri(request);
|
|
|
|
- Function<String, String> templateResolver = templateResolver(applicationUri, relyingPartyRegistration);
|
|
|
|
- String relyingPartyEntityId = templateResolver.apply(relyingPartyRegistration.getEntityId());
|
|
|
|
- String assertionConsumerServiceLocation = templateResolver
|
|
|
|
- .apply(relyingPartyRegistration.getAssertionConsumerServiceLocation());
|
|
|
|
- String singleLogoutServiceLocation = templateResolver
|
|
|
|
- .apply(relyingPartyRegistration.getSingleLogoutServiceLocation());
|
|
|
|
- String singleLogoutServiceResponseLocation = templateResolver
|
|
|
|
- .apply(relyingPartyRegistration.getSingleLogoutServiceResponseLocation());
|
|
|
|
- return relyingPartyRegistration.mutate().entityId(relyingPartyEntityId)
|
|
|
|
- .assertionConsumerServiceLocation(assertionConsumerServiceLocation)
|
|
|
|
- .singleLogoutServiceLocation(singleLogoutServiceLocation)
|
|
|
|
- .singleLogoutServiceResponseLocation(singleLogoutServiceResponseLocation).build();
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- private Function<String, String> templateResolver(String applicationUri, RelyingPartyRegistration relyingParty) {
|
|
|
|
- return (template) -> resolveUrlTemplate(template, applicationUri, relyingParty);
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- private static String resolveUrlTemplate(String template, String baseUrl, RelyingPartyRegistration relyingParty) {
|
|
|
|
- if (template == null) {
|
|
|
|
- return null;
|
|
|
|
- }
|
|
|
|
- String entityId = relyingParty.getAssertingPartyDetails().getEntityId();
|
|
|
|
- String registrationId = relyingParty.getRegistrationId();
|
|
|
|
- Map<String, String> uriVariables = new HashMap<>();
|
|
|
|
- UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(baseUrl).replaceQuery(null).fragment(null)
|
|
|
|
|
|
+ UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request, registration);
|
|
|
|
+ return registration.mutate().entityId(uriResolver.resolve(registration.getEntityId()))
|
|
|
|
+ .assertionConsumerServiceLocation(
|
|
|
|
+ uriResolver.resolve(registration.getAssertionConsumerServiceLocation()))
|
|
|
|
+ .singleLogoutServiceLocation(uriResolver.resolve(registration.getSingleLogoutServiceLocation()))
|
|
|
|
+ .singleLogoutServiceResponseLocation(
|
|
|
|
+ uriResolver.resolve(registration.getSingleLogoutServiceResponseLocation()))
|
|
.build();
|
|
.build();
|
|
- String scheme = uriComponents.getScheme();
|
|
|
|
- uriVariables.put("baseScheme", (scheme != null) ? scheme : "");
|
|
|
|
- String host = uriComponents.getHost();
|
|
|
|
- uriVariables.put("baseHost", (host != null) ? host : "");
|
|
|
|
- // following logic is based on HierarchicalUriComponents#toUriString()
|
|
|
|
- int port = uriComponents.getPort();
|
|
|
|
- uriVariables.put("basePort", (port == -1) ? "" : ":" + port);
|
|
|
|
- String path = uriComponents.getPath();
|
|
|
|
- if (StringUtils.hasLength(path) && path.charAt(0) != PATH_DELIMITER) {
|
|
|
|
- path = PATH_DELIMITER + path;
|
|
|
|
- }
|
|
|
|
- uriVariables.put("basePath", (path != null) ? path : "");
|
|
|
|
- uriVariables.put("baseUrl", uriComponents.toUriString());
|
|
|
|
- uriVariables.put("entityId", StringUtils.hasText(entityId) ? entityId : "");
|
|
|
|
- uriVariables.put("registrationId", StringUtils.hasText(registrationId) ? registrationId : "");
|
|
|
|
- return UriComponentsBuilder.fromUriString(template).buildAndExpand(uriVariables).toUriString();
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- private static String getApplicationUri(HttpServletRequest request) {
|
|
|
|
- UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request))
|
|
|
|
- .replacePath(request.getContextPath()).replaceQuery(null).fragment(null).build();
|
|
|
|
- return uriComponents.toUriString();
|
|
|
|
}
|
|
}
|
|
|
|
|
|
}
|
|
}
|