浏览代码

Polish OneTimeTokenLoginConfigurer

Signed-off-by: DingHao <dh.hiekn@gmail.com>
DingHao 7 月之前
父节点
当前提交
f7e0f7fa8a

+ 21 - 30
config/src/main/java/org/springframework/security/config/annotation/web/configurers/ott/OneTimeTokenLoginConfigurer.java

@@ -18,7 +18,6 @@ package org.springframework.security.config.annotation.web.configurers.ott;
 
 import java.util.Collections;
 import java.util.Map;
-import java.util.Objects;
 
 import jakarta.servlet.http.HttpServletRequest;
 
@@ -91,7 +90,7 @@ public final class OneTimeTokenLoginConfigurer<H extends HttpSecurityBuilder<H>>
 
 	@Override
 	public void init(H http) {
-		AuthenticationProvider authenticationProvider = getAuthenticationProvider(http);
+		AuthenticationProvider authenticationProvider = getAuthenticationProvider();
 		http.authenticationProvider(postProcess(authenticationProvider));
 		configureDefaultLoginPage(http);
 	}
@@ -138,17 +137,19 @@ public final class OneTimeTokenLoginConfigurer<H extends HttpSecurityBuilder<H>>
 	}
 
 	private void configureOttGenerateFilter(H http) {
-		GenerateOneTimeTokenFilter generateFilter = new GenerateOneTimeTokenFilter(getOneTimeTokenService(http),
-				getOneTimeTokenGenerationSuccessHandler(http));
+		GenerateOneTimeTokenFilter generateFilter = new GenerateOneTimeTokenFilter(getOneTimeTokenService(),
+				getOneTimeTokenGenerationSuccessHandler());
 		generateFilter.setRequestMatcher(antMatcher(HttpMethod.POST, this.tokenGeneratingUrl));
-		generateFilter.setRequestResolver(getGenerateRequestResolver(http));
+		generateFilter.setRequestResolver(getGenerateRequestResolver());
 		http.addFilter(postProcess(generateFilter));
 		http.addFilter(DefaultResourcesFilter.css());
 	}
 
-	private OneTimeTokenGenerationSuccessHandler getOneTimeTokenGenerationSuccessHandler(H http) {
+	private OneTimeTokenGenerationSuccessHandler getOneTimeTokenGenerationSuccessHandler() {
 		if (this.oneTimeTokenGenerationSuccessHandler == null) {
-			this.oneTimeTokenGenerationSuccessHandler = getBeanOrNull(http, OneTimeTokenGenerationSuccessHandler.class);
+			this.oneTimeTokenGenerationSuccessHandler = this.context
+				.getBeanProvider(OneTimeTokenGenerationSuccessHandler.class)
+				.getIfUnique();
 		}
 		if (this.oneTimeTokenGenerationSuccessHandler == null) {
 			throw new IllegalStateException("""
@@ -170,12 +171,12 @@ public final class OneTimeTokenLoginConfigurer<H extends HttpSecurityBuilder<H>>
 		http.addFilter(postProcess(submitPage));
 	}
 
-	private AuthenticationProvider getAuthenticationProvider(H http) {
+	private AuthenticationProvider getAuthenticationProvider() {
 		if (this.authenticationProvider != null) {
 			return this.authenticationProvider;
 		}
-		UserDetailsService userDetailsService = getContext().getBean(UserDetailsService.class);
-		this.authenticationProvider = new OneTimeTokenAuthenticationProvider(getOneTimeTokenService(http),
+		UserDetailsService userDetailsService = this.context.getBean(UserDetailsService.class);
+		this.authenticationProvider = new OneTimeTokenAuthenticationProvider(getOneTimeTokenService(),
 				userDetailsService);
 		return this.authenticationProvider;
 	}
@@ -321,44 +322,34 @@ public final class OneTimeTokenLoginConfigurer<H extends HttpSecurityBuilder<H>>
 		return this;
 	}
 
-	private GenerateOneTimeTokenRequestResolver getGenerateRequestResolver(H http) {
+	private GenerateOneTimeTokenRequestResolver getGenerateRequestResolver() {
 		if (this.requestResolver != null) {
 			return this.requestResolver;
 		}
-		GenerateOneTimeTokenRequestResolver bean = getBeanOrNull(http, GenerateOneTimeTokenRequestResolver.class);
-		this.requestResolver = Objects.requireNonNullElseGet(bean, DefaultGenerateOneTimeTokenRequestResolver::new);
+		this.requestResolver = this.context.getBeanProvider(GenerateOneTimeTokenRequestResolver.class)
+			.getIfUnique(DefaultGenerateOneTimeTokenRequestResolver::new);
 		return this.requestResolver;
 	}
 
-	private OneTimeTokenService getOneTimeTokenService(H http) {
+	private OneTimeTokenService getOneTimeTokenService() {
 		if (this.oneTimeTokenService != null) {
 			return this.oneTimeTokenService;
 		}
-		OneTimeTokenService bean = getBeanOrNull(http, OneTimeTokenService.class);
-		if (bean != null) {
-			this.oneTimeTokenService = bean;
-		}
-		else {
-			this.oneTimeTokenService = new InMemoryOneTimeTokenService();
-		}
+		this.oneTimeTokenService = this.context.getBeanProvider(OneTimeTokenService.class)
+			.getIfUnique(InMemoryOneTimeTokenService::new);
 		return this.oneTimeTokenService;
 	}
 
-	private <C> C getBeanOrNull(H http, Class<C> clazz) {
-		ApplicationContext context = http.getSharedObject(ApplicationContext.class);
-		if (context == null) {
-			return null;
-		}
-
-		return context.getBeanProvider(clazz).getIfUnique();
-	}
-
 	private Map<String, String> hiddenInputs(HttpServletRequest request) {
 		CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName());
 		return (token != null) ? Collections.singletonMap(token.getParameterName(), token.getToken())
 				: Collections.emptyMap();
 	}
 
+	/**
+	 * @deprecated Use this.context instead
+	 */
+	@Deprecated
 	public ApplicationContext getContext() {
 		return this.context;
 	}