瀏覽代碼

Polish Saml2WebSsoAuthenticationRequestFilter

- Updated formatting
- Reordered methods
- Removed a method

These changes will hopefully simplify future contribution.

Issue gh-6019
Josh Cummings 5 年之前
父節點
當前提交
95f0d02d79

+ 53 - 64
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java

@@ -35,14 +35,13 @@ import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher.MatchResult;
 import org.springframework.util.Assert;
+import org.springframework.util.StringUtils;
 import org.springframework.web.filter.OncePerRequestFilter;
 import org.springframework.web.util.HtmlUtils;
 import org.springframework.web.util.UriComponentsBuilder;
 import org.springframework.web.util.UriUtils;
 
-import static java.lang.String.format;
 import static java.nio.charset.StandardCharsets.ISO_8859_1;
-import static org.springframework.util.StringUtils.hasText;
 
 /**
  * This {@code Filter} formulates a
@@ -128,6 +127,7 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
 	@Override
 	protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
 			throws ServletException, IOException {
+
 		MatchResult matcher = this.redirectMatcher.matcher(request);
 		if (!matcher.isMatch()) {
 			filterChain.doFilter(request, response);
@@ -135,65 +135,28 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
 		}
 
 		String registrationId = matcher.getVariables().get("registrationId");
-		RelyingPartyRegistration relyingParty = this.relyingPartyRegistrationRepository.findByRegistrationId(registrationId);
+		RelyingPartyRegistration relyingParty =
+				this.relyingPartyRegistrationRepository.findByRegistrationId(registrationId);
 		if (relyingParty == null) {
 			response.sendError(HttpServletResponse.SC_UNAUTHORIZED);
 			return;
 		}
 		if (this.logger.isDebugEnabled()) {
-			this.logger.debug(format("Creating SAML2 SP Authentication Request for IDP[%s]", relyingParty.getRegistrationId()));
+			this.logger.debug("Creating SAML 2.0 Authentication Request for Asserting Party [" +
+					relyingParty.getRegistrationId() + "]");
 		}
-		Saml2AuthenticationRequestContext authnRequestCtx = createRedirectAuthenticationRequestContext(relyingParty, request);
+		Saml2AuthenticationRequestContext context = createRedirectAuthenticationRequestContext(request, relyingParty);
 		if (relyingParty.getProviderDetails().getBinding() == Saml2MessageBinding.REDIRECT) {
-			sendRedirect(response, authnRequestCtx);
+			sendRedirect(response, context);
 		}
 		else {
-			sendPost(response, authnRequestCtx);
-		}
-	}
-
-	private void sendRedirect(HttpServletResponse response, Saml2AuthenticationRequestContext authnRequestCtx)
-			throws IOException {
-		String redirectUrl = createSamlRequestRedirectUrl(authnRequestCtx);
-		response.sendRedirect(redirectUrl);
-	}
-
-	private void sendPost(HttpServletResponse response, Saml2AuthenticationRequestContext authnRequestCtx)
-			throws IOException {
-		Saml2PostAuthenticationRequest authNData =
-				this.authenticationRequestFactory.createPostAuthenticationRequest(authnRequestCtx);
-		String html = createSamlPostRequestFormData(authNData);
-		response.setContentType(MediaType.TEXT_HTML_VALUE);
-		response.getWriter().write(html);
-	}
-
-	private String createSamlRequestRedirectUrl(Saml2AuthenticationRequestContext authnRequestCtx) {
-
-		Saml2RedirectAuthenticationRequest authNData =
-				this.authenticationRequestFactory.createRedirectAuthenticationRequest(authnRequestCtx);
-		UriComponentsBuilder uriBuilder = UriComponentsBuilder.fromUriString(authNData.getAuthenticationRequestUri());
-		addParameter("SAMLRequest", authNData.getSamlRequest(), uriBuilder);
-		addParameter("RelayState", authNData.getRelayState(), uriBuilder);
-		addParameter("SigAlg", authNData.getSigAlg(), uriBuilder);
-		addParameter("Signature", authNData.getSignature(), uriBuilder);
-		return uriBuilder
-				.build(true)
-				.toUriString();
-	}
-
-	private void addParameter(String name, String value, UriComponentsBuilder builder) {
-		Assert.hasText(name, "name cannot be empty or null");
-		if (hasText(value)) {
-			builder.queryParam(
-					UriUtils.encode(name, ISO_8859_1),
-					UriUtils.encode(value, ISO_8859_1)
-			);
+			sendPost(response, context);
 		}
 	}
 
 	private Saml2AuthenticationRequestContext createRedirectAuthenticationRequestContext(
-			RelyingPartyRegistration relyingParty,
-			HttpServletRequest request) {
+			HttpServletRequest request, RelyingPartyRegistration relyingParty) {
+
 		String applicationUri = Saml2ServletUtils.getApplicationUri(request);
 		Function<String, String> resolver = templateResolver(applicationUri, relyingParty);
 		String localSpEntityId = resolver.apply(relyingParty.getLocalEntityIdTemplate());
@@ -210,17 +173,45 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
 		return template -> Saml2ServletUtils.resolveUrlTemplate(template, applicationUri, relyingParty);
 	}
 
-	private String htmlEscape(String value) {
-		if (hasText(value)) {
-			return HtmlUtils.htmlEscape(value);
+	private void sendRedirect(HttpServletResponse response, Saml2AuthenticationRequestContext context)
+			throws IOException {
+		Saml2RedirectAuthenticationRequest authenticationRequest =
+				this.authenticationRequestFactory.createRedirectAuthenticationRequest(context);
+		UriComponentsBuilder uriBuilder = UriComponentsBuilder
+				.fromUriString(authenticationRequest.getAuthenticationRequestUri());
+		addParameter("SAMLRequest", authenticationRequest.getSamlRequest(), uriBuilder);
+		addParameter("RelayState", authenticationRequest.getRelayState(), uriBuilder);
+		addParameter("SigAlg", authenticationRequest.getSigAlg(), uriBuilder);
+		addParameter("Signature", authenticationRequest.getSignature(), uriBuilder);
+		String redirectUrl = uriBuilder
+				.build(true)
+				.toUriString();
+		response.sendRedirect(redirectUrl);
+	}
+
+	private void addParameter(String name, String value, UriComponentsBuilder builder) {
+		Assert.hasText(name, "name cannot be empty or null");
+		if (StringUtils.hasText(value)) {
+			builder.queryParam(
+					UriUtils.encode(name, ISO_8859_1),
+					UriUtils.encode(value, ISO_8859_1)
+			);
 		}
-		return value;
 	}
 
-	private String createSamlPostRequestFormData(Saml2PostAuthenticationRequest request) {
-		String destination = request.getAuthenticationRequestUri();
-		String relayState = htmlEscape(request.getRelayState());
-		String samlRequest = htmlEscape(request.getSamlRequest());
+	private void sendPost(HttpServletResponse response, Saml2AuthenticationRequestContext context)
+			throws IOException {
+		Saml2PostAuthenticationRequest authenticationRequest =
+				this.authenticationRequestFactory.createPostAuthenticationRequest(context);
+		String html = createSamlPostRequestFormData(authenticationRequest);
+		response.setContentType(MediaType.TEXT_HTML_VALUE);
+		response.getWriter().write(html);
+	}
+
+	private String createSamlPostRequestFormData(Saml2PostAuthenticationRequest authenticationRequest) {
+		String authenticationRequestUri = authenticationRequest.getAuthenticationRequestUri();
+		String relayState = authenticationRequest.getRelayState();
+		String samlRequest = authenticationRequest.getSamlRequest();
 		StringBuilder postHtml = new StringBuilder()
 				.append("<!DOCTYPE html>\n")
 				.append("<html>\n")
@@ -235,16 +226,15 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
 				.append("            </p>\n")
 				.append("        </noscript>\n")
 				.append("        \n")
-				.append("        <form action=\"").append(destination).append("\" method=\"post\">\n")
+				.append("        <form action=\"").append(authenticationRequestUri).append("\" method=\"post\">\n")
 				.append("            <div>\n")
 				.append("                <input type=\"hidden\" name=\"SAMLRequest\" value=\"")
-				.append(samlRequest)
-				.append("\"/>\n")
-				;
-		if (hasText(relayState)) {
+				.append(HtmlUtils.htmlEscape(samlRequest))
+				.append("\"/>\n");
+		if (StringUtils.hasText(relayState)) {
 			postHtml
 					.append("                <input type=\"hidden\" name=\"RelayState\" value=\"")
-					.append(relayState)
+					.append(HtmlUtils.htmlEscape(relayState))
 					.append("\"/>\n");
 		}
 		postHtml
@@ -257,8 +247,7 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
 				.append("        </form>\n")
 				.append("        \n")
 				.append("    </body>\n")
-				.append("</html>")
-		;
+				.append("</html>");
 		return postHtml.toString();
 	}
 }