Explorar o código

Remove Unecessary Code

Josh Cummings hai 4 meses
pai
achega
5436fd5574

+ 5 - 10
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java

@@ -16,8 +16,6 @@
 
 package org.springframework.security.saml2.provider.service.web;
 
-import java.util.function.Function;
-
 import jakarta.servlet.http.HttpServletRequest;
 
 import org.springframework.http.HttpMethod;
@@ -43,7 +41,7 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo
 
 	private final RelyingPartyRegistrationResolver relyingPartyRegistrationResolver;
 
-	private Function<HttpServletRequest, AbstractSaml2AuthenticationRequest> loader;
+	private Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository;
 
 	/**
 	 * Constructs a {@link Saml2AuthenticationTokenConverter} given a strategy for
@@ -54,12 +52,13 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo
 	public Saml2AuthenticationTokenConverter(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) {
 		Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null");
 		this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver;
-		this.loader = new HttpSessionSaml2AuthenticationRequestRepository()::loadAuthenticationRequest;
+		this.authenticationRequestRepository = new HttpSessionSaml2AuthenticationRequestRepository();
 	}
 
 	@Override
 	public Saml2AuthenticationToken convert(HttpServletRequest request) {
-		AbstractSaml2AuthenticationRequest authenticationRequest = loadAuthenticationRequest(request);
+		AbstractSaml2AuthenticationRequest authenticationRequest = this.authenticationRequestRepository
+			.loadAuthenticationRequest(request);
 		String relyingPartyRegistrationId = (authenticationRequest != null)
 				? authenticationRequest.getRelyingPartyRegistrationId() : null;
 		RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationResolver.resolve(request,
@@ -84,11 +83,7 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo
 	public void setAuthenticationRequestRepository(
 			Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository) {
 		Assert.notNull(authenticationRequestRepository, "authenticationRequestRepository cannot be null");
-		this.loader = authenticationRequestRepository::loadAuthenticationRequest;
-	}
-
-	private AbstractSaml2AuthenticationRequest loadAuthenticationRequest(HttpServletRequest request) {
-		return this.loader.apply(request);
+		this.authenticationRequestRepository = authenticationRequestRepository;
 	}
 
 	private String decode(HttpServletRequest request) {

+ 3 - 6
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/Saml2WebSsoAuthenticationFilter.java

@@ -29,7 +29,6 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2A
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
 import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
 import org.springframework.security.saml2.provider.service.web.HttpSessionSaml2AuthenticationRequestRepository;
-import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver;
 import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestRepository;
 import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter;
 import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
@@ -77,9 +76,7 @@ public class Saml2WebSsoAuthenticationFilter extends AbstractAuthenticationProce
 	public Saml2WebSsoAuthenticationFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository,
 			String filterProcessesUrl) {
 		this(new Saml2AuthenticationTokenConverter(
-				(RelyingPartyRegistrationResolver) new DefaultRelyingPartyRegistrationResolver(
-						relyingPartyRegistrationRepository)),
-				filterProcessesUrl);
+				new DefaultRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository)), filterProcessesUrl);
 		Assert.isTrue(filterProcessesUrl.contains("{registrationId}"),
 				"filterProcessesUrl must contain a {registrationId} match variable");
 	}
@@ -159,9 +156,9 @@ public class Saml2WebSsoAuthenticationFilter extends AbstractAuthenticationProce
 	}
 
 	private void setDetails(HttpServletRequest request, Authentication authentication) {
-		if (AbstractAuthenticationToken.class.isAssignableFrom(authentication.getClass())) {
+		if (authentication instanceof AbstractAuthenticationToken token) {
 			Object details = this.authenticationDetailsSource.buildDetails(request);
-			((AbstractAuthenticationToken) authentication).setDetails(details);
+			token.setDetails(details);
 		}
 	}
 

+ 23 - 3
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/Saml2WebSsoAuthenticationFilterTests.java

@@ -16,6 +16,7 @@
 
 package org.springframework.security.saml2.provider.service.web.authentication;
 
+import jakarta.servlet.http.HttpServletRequest;
 import jakarta.servlet.http.HttpServletResponse;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
@@ -94,16 +95,18 @@ public class Saml2WebSsoAuthenticationFilterTests {
 
 	@Test
 	public void requiresAuthenticationWhenHappyPathThenReturnsTrue() {
-		assertThat(this.filter.requiresAuthentication(this.request, this.response)).isTrue();
+		RequiresAuthenticationExposingFilter filter = new RequiresAuthenticationExposingFilter(this.repository);
+		assertThat(filter.requiresAuthentication(this.request, this.response)).isTrue();
 	}
 
 	@Test
 	public void requiresAuthenticationWhenCustomProcessingUrlThenReturnsTrue() {
-		this.filter = new Saml2WebSsoAuthenticationFilter(this.repository, "/some/other/path/{registrationId}");
+		RequiresAuthenticationExposingFilter filter = new RequiresAuthenticationExposingFilter(this.repository,
+				"/some/other/path/{registrationId}");
 		this.request.setRequestURI("/some/other/path/idp-registration-id");
 		this.request.setPathInfo("/some/other/path/idp-registration-id");
 		this.request.setParameter(Saml2ParameterNames.SAML_RESPONSE, "xml-data-goes-here");
-		assertThat(this.filter.requiresAuthentication(this.request, this.response)).isTrue();
+		assertThat(filter.requiresAuthentication(this.request, this.response)).isTrue();
 	}
 
 	@Test
@@ -212,4 +215,21 @@ public class Saml2WebSsoAuthenticationFilterTests {
 		verify(this.repository).findByRegistrationId("registration-id");
 	}
 
+	static final class RequiresAuthenticationExposingFilter extends Saml2WebSsoAuthenticationFilter {
+
+		RequiresAuthenticationExposingFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) {
+			super(relyingPartyRegistrationRepository);
+		}
+
+		RequiresAuthenticationExposingFilter(RelyingPartyRegistrationRepository registrations, String url) {
+			super(registrations, url);
+		}
+
+		@Override
+		protected boolean requiresAuthentication(HttpServletRequest request, HttpServletResponse response) {
+			return super.requiresAuthentication(request, response);
+		}
+
+	}
+
 }