Prechádzať zdrojové kódy

feat: Add option to specify a custom ServerAuthenticationConverter for x509()

Signed-off-by: blake_bauman <blake_bauman@apple.com>
blake_bauman 3 mesiacov pred
rodič
commit
b502697731

+ 16 - 1
config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java

@@ -3231,6 +3231,8 @@ public class ServerHttpSecurity {
 
 		private ReactiveAuthenticationManager authenticationManager;
 
+		private ServerAuthenticationConverter serverAuthenticationConverter;
+
 		private X509Spec() {
 		}
 
@@ -3244,11 +3246,17 @@ public class ServerHttpSecurity {
 			return this;
 		}
 
+		public X509Spec serverAuthenticationConverter(ServerAuthenticationConverter serverAuthenticationConverter) {
+			this.serverAuthenticationConverter = serverAuthenticationConverter;
+			return this;
+		}
+
 		protected void configure(ServerHttpSecurity http) {
 			ReactiveAuthenticationManager authenticationManager = getAuthenticationManager();
 			X509PrincipalExtractor principalExtractor = getPrincipalExtractor();
+			ServerAuthenticationConverter converter = getServerAuthenticationConverter(principalExtractor);
 			AuthenticationWebFilter filter = new AuthenticationWebFilter(authenticationManager);
-			filter.setServerAuthenticationConverter(new ServerX509AuthenticationConverter(principalExtractor));
+			filter.setServerAuthenticationConverter(serverAuthenticationConverter);
 			http.addFilterAt(filter, SecurityWebFiltersOrder.AUTHENTICATION);
 		}
 
@@ -3267,6 +3275,13 @@ public class ServerHttpSecurity {
 			return new ReactivePreAuthenticatedAuthenticationManager(userDetailsService);
 		}
 
+		private ServerAuthenticationConverter getServerAuthenticationConverter(X509PrincipalExtractor extractor) {
+			if (this.serverAuthenticationConverter != null) {
+				return this.serverAuthenticationConverter;
+			}
+			return new ServerX509AuthenticationConverter(extractor);
+		}
+
 	}
 
 	public final class OAuth2LoginSpec {

+ 23 - 0
config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java

@@ -60,6 +60,7 @@ import org.springframework.security.web.server.authentication.AnonymousAuthentic
 import org.springframework.security.web.server.authentication.DelegatingServerAuthenticationSuccessHandler;
 import org.springframework.security.web.server.authentication.HttpBasicServerAuthenticationEntryPoint;
 import org.springframework.security.web.server.authentication.HttpStatusServerEntryPoint;
+import org.springframework.security.web.server.authentication.ServerAuthenticationConverter;
 import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler;
 import org.springframework.security.web.server.authentication.ServerAuthenticationSuccessHandler;
 import org.springframework.security.web.server.authentication.ServerX509AuthenticationConverter;
@@ -497,6 +498,17 @@ public class ServerHttpSecurityTests {
 		assertThat(x509WebFilter).isNotNull();
 	}
 
+	@Test
+	public void x509WithConverterAndNoExtractorThenAddsX509Filter() {
+		ServerAuthenticationConverter mockConverter = mock(ServerAuthenticationConverter.class);
+		this.http.x509((x509) -> x509.serverAuthenticationConverter(mockConverter));
+		SecurityWebFilterChain securityWebFilterChain = this.http.build();
+		WebFilter x509WebFilter = securityWebFilterChain.getWebFilters()
+			.filter(filter -> matchesX509Converter(filter, mockConverter))
+			.blockFirst();
+		assertThat(x509WebFilter).isNotNull();
+	}
+
 	@Test
 	public void addsX509FilterWhenX509AuthenticationIsConfiguredWithDefaults() {
 		this.http.x509(withDefaults());
@@ -769,6 +781,17 @@ public class ServerHttpSecurityTests {
 		}
 	}
 
+	private boolean matchesX509Converter(WebFilter filter, ServerAuthenticationConverter expectedConverter) {
+		try {
+			Object converter = ReflectionTestUtils.getField(filter, "authenticationConverter");
+			return converter.equals(expectedConverter);
+		}
+		catch (IllegalArgumentException ex) {
+			// field doesn't exist
+			return false;
+		}
+	}
+
 	private <T extends WebFilter> Optional<T> getWebFilter(SecurityWebFilterChain filterChain, Class<T> filterClass) {
 		return (Optional<T>) filterChain.getWebFilters()
 			.filter(Objects::nonNull)