소스 검색

Pick Up OidcSessionRegistry Bean

Closes gh-15813
Josh Cummings 11 달 전
부모
커밋
b311b811a1

+ 9 - 2
config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerUtils.java

@@ -116,10 +116,17 @@ final class OAuth2ClientConfigurerUtils {
 
 	static <B extends HttpSecurityBuilder<B>> OidcSessionRegistry getOidcSessionRegistry(B builder) {
 		OidcSessionRegistry sessionRegistry = builder.getSharedObject(OidcSessionRegistry.class);
-		if (sessionRegistry == null) {
+		if (sessionRegistry != null) {
+			return sessionRegistry;
+		}
+		ApplicationContext context = builder.getSharedObject(ApplicationContext.class);
+		if (context.getBeanNamesForType(OidcSessionRegistry.class).length == 1) {
+			sessionRegistry = context.getBean(OidcSessionRegistry.class);
+		}
+		else {
 			sessionRegistry = new InMemoryOidcSessionRegistry();
-			builder.setSharedObject(OidcSessionRegistry.class, sessionRegistry);
 		}
+		builder.setSharedObject(OidcSessionRegistry.class, sessionRegistry);
 		return sessionRegistry;
 	}
 

+ 1 - 1
config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java

@@ -5496,7 +5496,7 @@ public class ServerHttpSecurity {
 
 		private ReactiveOidcSessionRegistry getSessionRegistry() {
 			if (this.sessionRegistry == null && ServerHttpSecurity.this.oauth2Login == null) {
-				return new InMemoryReactiveOidcSessionRegistry();
+				return getBeanOrDefault(ReactiveOidcSessionRegistry.class, new InMemoryReactiveOidcSessionRegistry());
 			}
 			if (this.sessionRegistry == null) {
 				return ServerHttpSecurity.this.oauth2Login.oidcSessionRegistry;

+ 10 - 14
config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcLogoutConfigurerTests.java

@@ -396,15 +396,13 @@ public class OidcLogoutConfigurerTests {
 	@Import(RegistrationConfig.class)
 	static class SelfLogoutUriConfig {
 
-		private final OidcSessionRegistry sessionRegistry = new InMemoryOidcSessionRegistry();
-
 		@Bean
 		@Order(1)
 		SecurityFilterChain filters(HttpSecurity http) throws Exception {
 			// @formatter:off
 			http
 				.authorizeHttpRequests((authorize) -> authorize.anyRequest().authenticated())
-				.oauth2Login((oauth2) -> oauth2.oidcSessionRegistry(this.sessionRegistry))
+				.oauth2Login(Customizer.withDefaults())
 				.oidcLogout((oidc) -> oidc
 					.backChannel(Customizer.withDefaults())
 				);
@@ -413,11 +411,6 @@ public class OidcLogoutConfigurerTests {
 			return http.build();
 		}
 
-		@Bean
-		OidcBackChannelLogoutHandler oidcLogoutHandler() {
-			return new OidcBackChannelLogoutHandler(this.sessionRegistry);
-		}
-
 	}
 
 	@Configuration
@@ -427,15 +420,13 @@ public class OidcLogoutConfigurerTests {
 
 		private final MockWebServer server = new MockWebServer();
 
-		private final OidcSessionRegistry sessionRegistry = new InMemoryOidcSessionRegistry();
-
 		@Bean
 		@Order(1)
 		SecurityFilterChain filters(HttpSecurity http) throws Exception {
 			// @formatter:off
 			http
 				.authorizeHttpRequests((authorize) -> authorize.anyRequest().authenticated())
-				.oauth2Login((oauth2) -> oauth2.oidcSessionRegistry(this.sessionRegistry))
+				.oauth2Login(Customizer.withDefaults())
 				.oidcLogout((oidc) -> oidc
 					.backChannel(Customizer.withDefaults())
 				);
@@ -445,8 +436,13 @@ public class OidcLogoutConfigurerTests {
 		}
 
 		@Bean
-		OidcBackChannelLogoutHandler oidcLogoutHandler() {
-			OidcBackChannelLogoutHandler logoutHandler = new OidcBackChannelLogoutHandler(this.sessionRegistry);
+		OidcSessionRegistry sessionRegistry() {
+			return new InMemoryOidcSessionRegistry();
+		}
+
+		@Bean
+		OidcBackChannelLogoutHandler oidcLogoutHandler(OidcSessionRegistry sessionRegistry) {
+			OidcBackChannelLogoutHandler logoutHandler = new OidcBackChannelLogoutHandler(sessionRegistry);
 			logoutHandler.setSessionCookieName("SESSION");
 			return logoutHandler;
 		}
@@ -485,7 +481,7 @@ public class OidcLogoutConfigurerTests {
 			// @formatter:off
 			http
 				.authorizeHttpRequests((authorize) -> authorize.anyRequest().authenticated())
-				.oauth2Login((oauth2) -> oauth2.oidcSessionRegistry(this.sessionRegistry))
+				.oauth2Login(Customizer.withDefaults())
 				.oidcLogout((oidc) -> oidc.backChannel(Customizer.withDefaults()));
 			// @formatter:on
 

+ 9 - 7
config/src/test/java/org/springframework/security/config/web/server/OidcLogoutSpecTests.java

@@ -519,8 +519,6 @@ public class OidcLogoutSpecTests {
 	@Import(RegistrationConfig.class)
 	static class CookieConfig {
 
-		private final ReactiveOidcSessionRegistry sessionRegistry = new InMemoryReactiveOidcSessionRegistry();
-
 		private final MockWebServer server = new MockWebServer();
 
 		@Bean
@@ -529,7 +527,7 @@ public class OidcLogoutSpecTests {
 			// @formatter:off
 			http
 				.authorizeExchange((authorize) -> authorize.anyExchange().authenticated())
-				.oauth2Login((oauth2) -> oauth2.oidcSessionRegistry(this.sessionRegistry))
+				.oauth2Login(Customizer.withDefaults())
 				.oidcLogout((oidc) -> oidc
 					.backChannel(Customizer.withDefaults())
 				);
@@ -539,9 +537,13 @@ public class OidcLogoutSpecTests {
 		}
 
 		@Bean
-		OidcBackChannelServerLogoutHandler oidcLogoutHandler() {
-			OidcBackChannelServerLogoutHandler logoutHandler = new OidcBackChannelServerLogoutHandler(
-					this.sessionRegistry);
+		ReactiveOidcSessionRegistry oidcSessionRegistry() {
+			return new InMemoryReactiveOidcSessionRegistry();
+		}
+
+		@Bean
+		OidcBackChannelServerLogoutHandler oidcLogoutHandler(ReactiveOidcSessionRegistry sessionRegistry) {
+			OidcBackChannelServerLogoutHandler logoutHandler = new OidcBackChannelServerLogoutHandler(sessionRegistry);
 			logoutHandler.setSessionCookieName("JSESSIONID");
 			return logoutHandler;
 		}
@@ -580,7 +582,7 @@ public class OidcLogoutSpecTests {
 		// @formatter:off
 			http
 					.authorizeExchange((authorize) -> authorize.anyExchange().authenticated())
-					.oauth2Login((oauth2) -> oauth2.oidcSessionRegistry(this.sessionRegistry))
+					.oauth2Login(Customizer.withDefaults())
 					.oidcLogout((oidc) -> oidc.backChannel(Customizer.withDefaults()));
 			// @formatter:on