Przeglądaj źródła

Make Internal Logout URI Configurable

Closes gh-14609
Josh Cummings 1 rok temu
rodzic
commit
662cfed349

+ 31 - 8
config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcBackChannelLogoutHandler.java

@@ -19,6 +19,7 @@ package org.springframework.security.config.annotation.web.configurers.oauth2.cl
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collection;
+import java.util.HashMap;
 import java.util.Map;
 
 import jakarta.servlet.http.HttpServletRequest;
@@ -37,10 +38,12 @@ import org.springframework.security.oauth2.client.oidc.session.OidcSessionRegist
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter;
 import org.springframework.security.web.authentication.logout.LogoutHandler;
+import org.springframework.security.web.util.UrlUtils;
 import org.springframework.util.Assert;
 import org.springframework.web.client.RestClientException;
 import org.springframework.web.client.RestOperations;
 import org.springframework.web.client.RestTemplate;
+import org.springframework.web.util.UriComponents;
 import org.springframework.web.util.UriComponentsBuilder;
 
 /**
@@ -61,7 +64,7 @@ final class OidcBackChannelLogoutHandler implements LogoutHandler {
 
 	private RestOperations restOperations = new RestTemplate();
 
-	private String logoutEndpointName = "/logout";
+	private String logoutUri = "{baseScheme}://localhost{basePort}/logout";
 
 	private String sessionCookieName = "JSESSIONID";
 
@@ -112,12 +115,32 @@ final class OidcBackChannelLogoutHandler implements LogoutHandler {
 	}
 
 	String computeLogoutEndpoint(HttpServletRequest request) {
-		String url = request.getRequestURL().toString();
-		return UriComponentsBuilder.fromHttpUrl(url)
-			.host("localhost")
-			.replacePath(this.logoutEndpointName)
-			.build()
-			.toUriString();
+		// @formatter:off
+		UriComponents uriComponents = UriComponentsBuilder
+				.fromHttpUrl(UrlUtils.buildFullRequestUrl(request))
+				.replacePath(request.getContextPath())
+				.replaceQuery(null)
+				.fragment(null)
+				.build();
+
+		Map<String, String> uriVariables = new HashMap<>();
+		String scheme = uriComponents.getScheme();
+		uriVariables.put("baseScheme", (scheme != null) ? scheme : "");
+		uriVariables.put("baseUrl", uriComponents.toUriString());
+
+		String host = uriComponents.getHost();
+		uriVariables.put("baseHost", (host != null) ? host : "");
+
+		String path = uriComponents.getPath();
+		uriVariables.put("basePath", (path != null) ? path : "");
+
+		int port = uriComponents.getPort();
+		uriVariables.put("basePort", (port == -1) ? "" : ":" + port);
+
+		return UriComponentsBuilder.fromUriString(this.logoutUri)
+				.buildAndExpand(uriVariables)
+				.toUriString();
+		// @formatter:on
 	}
 
 	private OAuth2Error oauth2Error(Collection<String> errors) {
@@ -164,7 +187,7 @@ final class OidcBackChannelLogoutHandler implements LogoutHandler {
 	 */
 	void setLogoutUri(String logoutUri) {
 		Assert.hasText(logoutUri, "logoutUri cannot be empty");
-		this.logoutEndpointName = logoutUri;
+		this.logoutUri = logoutUri;
 	}
 
 	/**

+ 43 - 6
config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcLogoutConfigurer.java

@@ -17,6 +17,7 @@
 package org.springframework.security.config.annotation.web.configurers.oauth2.client;
 
 import java.util.function.Consumer;
+import java.util.function.Function;
 
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.authentication.ProviderManager;
@@ -123,7 +124,7 @@ public final class OidcLogoutConfigurer<B extends HttpSecurityBuilder<B>>
 		private final AuthenticationManager authenticationManager = new ProviderManager(
 				new OidcBackChannelLogoutAuthenticationProvider());
 
-		private LogoutHandler logoutHandler;
+		private Function<B, LogoutHandler> logoutHandler = this::logoutHandler;
 
 		private AuthenticationConverter authenticationConverter(B http) {
 			if (this.authenticationConverter == null) {
@@ -139,18 +140,54 @@ public final class OidcLogoutConfigurer<B extends HttpSecurityBuilder<B>>
 		}
 
 		private LogoutHandler logoutHandler(B http) {
-			if (this.logoutHandler == null) {
+			OidcBackChannelLogoutHandler logoutHandler = new OidcBackChannelLogoutHandler();
+			logoutHandler.setSessionRegistry(OAuth2ClientConfigurerUtils.getOidcSessionRegistry(http));
+			return logoutHandler;
+		}
+
+		/**
+		 * Use this endpoint when invoking a back-channel logout.
+		 *
+		 * <p>
+		 * The resulting {@link LogoutHandler} will {@code POST} the session cookie and
+		 * CSRF token to this endpoint to invalidate the corresponding end-user session.
+		 *
+		 * <p>
+		 * Supports URI templates like {@code {baseUrl}}, {@code {baseScheme}}, and
+		 * {@code {basePort}}.
+		 *
+		 * <p>
+		 * By default, the URI is set to
+		 * {@code {baseScheme}://localhost{basePort}/logout}, meaning that the scheme and
+		 * port of the original back-channel request is preserved, while the host and
+		 * endpoint are changed.
+		 *
+		 * <p>
+		 * If you are using Spring Security for the logout endpoint, the path part of this
+		 * URI should match the value configured there.
+		 *
+		 * <p>
+		 * Otherwise, this is handy in the event that your server configuration means that
+		 * the scheme, server name, or port in the {@code Host} header are different from
+		 * how you would address the same server internally.
+		 * @param logoutUri the URI to request logout on the back-channel
+		 * @return the {@link BackChannelLogoutConfigurer} for further customizations
+		 * @since 6.2.4
+		 */
+		public BackChannelLogoutConfigurer logoutUri(String logoutUri) {
+			this.logoutHandler = (http) -> {
 				OidcBackChannelLogoutHandler logoutHandler = new OidcBackChannelLogoutHandler();
 				logoutHandler.setSessionRegistry(OAuth2ClientConfigurerUtils.getOidcSessionRegistry(http));
-				this.logoutHandler = logoutHandler;
-			}
-			return this.logoutHandler;
+				logoutHandler.setLogoutUri(logoutUri);
+				return logoutHandler;
+			};
+			return this;
 		}
 
 		void configure(B http) {
 			OidcBackChannelLogoutFilter filter = new OidcBackChannelLogoutFilter(authenticationConverter(http),
 					authenticationManager());
-			filter.setLogoutHandler(logoutHandler(http));
+			filter.setLogoutHandler(this.logoutHandler.apply(http));
 			http.addFilterBefore(filter, CsrfFilter.class);
 		}
 

+ 32 - 10
config/src/main/java/org/springframework/security/config/web/server/OidcBackChannelServerLogoutHandler.java

@@ -18,6 +18,7 @@ package org.springframework.security.config.web.server;
 
 import java.nio.charset.StandardCharsets;
 import java.util.Collection;
+import java.util.HashMap;
 import java.util.Map;
 import java.util.concurrent.atomic.AtomicInteger;
 
@@ -30,6 +31,7 @@ import reactor.core.publisher.Mono;
 import org.springframework.core.io.buffer.DataBuffer;
 import org.springframework.http.HttpHeaders;
 import org.springframework.http.ResponseEntity;
+import org.springframework.http.server.reactive.ServerHttpRequest;
 import org.springframework.http.server.reactive.ServerHttpResponse;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.oauth2.client.oidc.authentication.logout.OidcLogoutToken;
@@ -42,6 +44,7 @@ import org.springframework.security.web.server.WebFilterExchange;
 import org.springframework.security.web.server.authentication.logout.ServerLogoutHandler;
 import org.springframework.util.Assert;
 import org.springframework.web.reactive.function.client.WebClient;
+import org.springframework.web.util.UriComponents;
 import org.springframework.web.util.UriComponentsBuilder;
 
 /**
@@ -62,7 +65,7 @@ final class OidcBackChannelServerLogoutHandler implements ServerLogoutHandler {
 
 	private WebClient web = WebClient.create();
 
-	private String logoutEndpointName = "/logout";
+	private String logoutUri = "{baseScheme}://localhost{basePort}/logout";
 
 	private String sessionCookieName = "SESSION";
 
@@ -108,17 +111,36 @@ final class OidcBackChannelServerLogoutHandler implements ServerLogoutHandler {
 		for (Map.Entry<String, String> credential : session.getAuthorities().entrySet()) {
 			headers.add(credential.getKey(), credential.getValue());
 		}
-		String logout = computeLogoutEndpoint(exchange);
+		String logout = computeLogoutEndpoint(exchange.getExchange().getRequest());
 		return this.web.post().uri(logout).headers((h) -> h.putAll(headers)).retrieve().toBodilessEntity();
 	}
 
-	String computeLogoutEndpoint(WebFilterExchange exchange) {
-		String url = exchange.getExchange().getRequest().getURI().toString();
-		return UriComponentsBuilder.fromHttpUrl(url)
-			.host("localhost")
-			.replacePath(this.logoutEndpointName)
-			.build()
-			.toUriString();
+	String computeLogoutEndpoint(ServerHttpRequest request) {
+		// @formatter:off
+		UriComponents uriComponents = UriComponentsBuilder.fromUri(request.getURI())
+				.replacePath(request.getPath().contextPath().value())
+				.replaceQuery(null)
+				.fragment(null)
+				.build();
+
+		Map<String, String> uriVariables = new HashMap<>();
+		String scheme = uriComponents.getScheme();
+		uriVariables.put("baseScheme", (scheme != null) ? scheme : "");
+		uriVariables.put("baseUrl", uriComponents.toUriString());
+
+		String host = uriComponents.getHost();
+		uriVariables.put("baseHost", (host != null) ? host : "");
+
+		String path = uriComponents.getPath();
+		uriVariables.put("basePath", (path != null) ? path : "");
+
+		int port = uriComponents.getPort();
+		uriVariables.put("basePort", (port == -1) ? "" : ":" + port);
+
+		return UriComponentsBuilder.fromUriString(this.logoutUri)
+				.buildAndExpand(uriVariables)
+				.toUriString();
+		// @formatter:on
 	}
 
 	private OAuth2Error oauth2Error(Collection<?> errors) {
@@ -168,7 +190,7 @@ final class OidcBackChannelServerLogoutHandler implements ServerLogoutHandler {
 	 */
 	void setLogoutUri(String logoutUri) {
 		Assert.hasText(logoutUri, "logoutUri cannot be empty");
-		this.logoutEndpointName = logoutUri;
+		this.logoutUri = logoutUri;
 	}
 
 	/**

+ 46 - 6
config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java

@@ -58,6 +58,7 @@ import org.springframework.security.authorization.AuthorizationDecision;
 import org.springframework.security.authorization.ObservationReactiveAuthorizationManager;
 import org.springframework.security.authorization.ReactiveAuthorizationManager;
 import org.springframework.security.config.Customizer;
+import org.springframework.security.config.annotation.web.configurers.oauth2.client.OidcLogoutConfigurer;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.GrantedAuthority;
 import org.springframework.security.core.authority.AuthorityUtils;
@@ -111,6 +112,7 @@ import org.springframework.security.oauth2.server.resource.web.access.server.Bea
 import org.springframework.security.oauth2.server.resource.web.server.BearerTokenServerAuthenticationEntryPoint;
 import org.springframework.security.oauth2.server.resource.web.server.authentication.ServerBearerTokenAuthenticationConverter;
 import org.springframework.security.web.PortMapper;
+import org.springframework.security.web.authentication.logout.LogoutHandler;
 import org.springframework.security.web.authentication.preauth.x509.SubjectDnX509PrincipalExtractor;
 import org.springframework.security.web.authentication.preauth.x509.X509PrincipalExtractor;
 import org.springframework.security.web.server.DefaultServerRedirectStrategy;
@@ -5484,7 +5486,7 @@ public class ServerHttpSecurity {
 
 			private final ReactiveAuthenticationManager authenticationManager = new OidcBackChannelLogoutReactiveAuthenticationManager();
 
-			private ServerLogoutHandler logoutHandler;
+			private Supplier<ServerLogoutHandler> logoutHandler = this::logoutHandler;
 
 			private ServerAuthenticationConverter authenticationConverter() {
 				if (this.authenticationConverter == null) {
@@ -5499,18 +5501,56 @@ public class ServerHttpSecurity {
 			}
 
 			private ServerLogoutHandler logoutHandler() {
-				if (this.logoutHandler == null) {
+				OidcBackChannelServerLogoutHandler logoutHandler = new OidcBackChannelServerLogoutHandler();
+				logoutHandler.setSessionRegistry(OidcLogoutSpec.this.getSessionRegistry());
+				return logoutHandler;
+			}
+
+			/**
+			 * Use this endpoint when invoking a back-channel logout.
+			 *
+			 * <p>
+			 * The resulting {@link LogoutHandler} will {@code POST} the session cookie
+			 * and CSRF token to this endpoint to invalidate the corresponding end-user
+			 * session.
+			 *
+			 * <p>
+			 * Supports URI templates like {@code {baseUrl}}, {@code {baseScheme}}, and
+			 * {@code {basePort}}.
+			 *
+			 * <p>
+			 * By default, the URI is set to
+			 * {@code {baseScheme}://localhost{basePort}/logout}, meaning that the scheme
+			 * and port of the original back-channel request is preserved, while the host
+			 * and endpoint are changed.
+			 *
+			 * <p>
+			 * If you are using Spring Security for the logout endpoint, the path part of
+			 * this URI should match the value configured there.
+			 *
+			 * <p>
+			 * Otherwise, this is handy in the event that your server configuration means
+			 * that the scheme, server name, or port in the {@code Host} header are
+			 * different from how you would address the same server internally.
+			 * @param logoutUri the URI to request logout on the back-channel
+			 * @return the {@link OidcLogoutConfigurer.BackChannelLogoutConfigurer} for
+			 * further customizations
+			 * @since 6.2.4
+			 */
+			public BackChannelLogoutConfigurer logoutUri(String logoutUri) {
+				this.logoutHandler = () -> {
 					OidcBackChannelServerLogoutHandler logoutHandler = new OidcBackChannelServerLogoutHandler();
 					logoutHandler.setSessionRegistry(OidcLogoutSpec.this.getSessionRegistry());
-					this.logoutHandler = logoutHandler;
-				}
-				return this.logoutHandler;
+					logoutHandler.setLogoutUri(logoutUri);
+					return logoutHandler;
+				};
+				return this;
 			}
 
 			void configure(ServerHttpSecurity http) {
 				OidcBackChannelLogoutWebFilter filter = new OidcBackChannelLogoutWebFilter(authenticationConverter(),
 						authenticationManager());
-				filter.setLogoutHandler(logoutHandler());
+				filter.setLogoutHandler(this.logoutHandler.get());
 				http.addFilterBefore(filter, SecurityWebFiltersOrder.CSRF);
 			}
 

+ 24 - 0
config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcBackChannelLogoutHandlerTests.java

@@ -35,4 +35,28 @@ public class OidcBackChannelLogoutHandlerTests {
 		assertThat(endpoint).isEqualTo("http://localhost:8090/logout");
 	}
 
+	@Test
+	public void computeLogoutEndpointWhenUsingBaseUrlTemplateThenServerName() {
+		OidcBackChannelLogoutHandler logoutHandler = new OidcBackChannelLogoutHandler();
+		logoutHandler.setLogoutUri("{baseUrl}/logout");
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", "/back-channel/logout");
+		request.setServerName("host.docker.internal");
+		request.setServerPort(8090);
+		String endpoint = logoutHandler.computeLogoutEndpoint(request);
+		assertThat(endpoint).isEqualTo("http://host.docker.internal:8090/logout");
+	}
+
+	// gh-14609
+	@Test
+	public void computeLogoutEndpointWhenLogoutUriThenUses() {
+		OidcBackChannelLogoutHandler logoutHandler = new OidcBackChannelLogoutHandler();
+		logoutHandler.setLogoutUri("http://localhost:8090/logout");
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", "/back-channel/logout");
+		request.setScheme("https");
+		request.setServerName("server-one.com");
+		request.setServerPort(80);
+		String endpoint = logoutHandler.computeLogoutEndpoint(request);
+		assertThat(endpoint).isEqualTo("http://localhost:8090/logout");
+	}
+
 }

+ 42 - 1
config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OidcLogoutConfigurerTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2023 the original author or authors.
+ * Copyright 2002-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -197,6 +197,25 @@ public class OidcLogoutConfigurerTests {
 		this.mvc.perform(get("/token/logout").session(three)).andExpect(status().isUnauthorized());
 	}
 
+	@Test
+	void logoutWhenRemoteLogoutUriThenUses() throws Exception {
+		this.spring.register(WebServerConfig.class, OidcProviderConfig.class, LogoutUriConfig.class).autowire();
+		String registrationId = this.clientRegistration.getRegistrationId();
+		MockHttpSession one = login();
+		String logoutToken = this.mvc.perform(get("/token/logout/all").session(one))
+			.andExpect(status().isOk())
+			.andReturn()
+			.getResponse()
+			.getContentAsString();
+		this.mvc
+			.perform(post(this.web.url("/logout/connect/back-channel/" + registrationId).toString())
+				.param("logout_token", logoutToken))
+			.andExpect(status().isBadRequest())
+			.andExpect(content().string(containsString("partial_logout")))
+			.andExpect(content().string(containsString("Connection refused")));
+		this.mvc.perform(get("/token/logout").session(one)).andExpect(status().isOk());
+	}
+
 	@Test
 	void logoutWhenRemoteLogoutFailsThenReportsPartialLogout() throws Exception {
 		this.spring.register(WebServerConfig.class, OidcProviderConfig.class, WithBrokenLogoutConfig.class).autowire();
@@ -312,6 +331,28 @@ public class OidcLogoutConfigurerTests {
 
 	}
 
+	@Configuration
+	@EnableWebSecurity
+	@Import(RegistrationConfig.class)
+	static class LogoutUriConfig {
+
+		@Bean
+		@Order(1)
+		SecurityFilterChain filters(HttpSecurity http) throws Exception {
+			// @formatter:off
+			http
+					.authorizeHttpRequests((authorize) -> authorize.anyRequest().authenticated())
+					.oauth2Login(Customizer.withDefaults())
+					.oidcLogout((oidc) -> oidc
+						.backChannel((backchannel) -> backchannel.logoutUri("http://localhost/wrong"))
+					);
+			// @formatter:on
+
+			return http.build();
+		}
+
+	}
+
 	@Configuration
 	@EnableWebSecurity
 	@Import(RegistrationConfig.class)

+ 21 - 0
config/src/test/java/org/springframework/security/config/web/server/OidcBackChannelServerLogoutHandlerTests.java

@@ -38,4 +38,25 @@ public class OidcBackChannelServerLogoutHandlerTests {
 		assertThat(endpoint).isEqualTo("https://localhost:8090/logout");
 	}
 
+	@Test
+	public void computeLogoutEndpointWhenUsingBaseUrlTemplateThenServerName() {
+		OidcBackChannelServerLogoutHandler logoutHandler = new OidcBackChannelServerLogoutHandler();
+		logoutHandler.setLogoutUri("{baseUrl}/logout");
+		MockServerHttpRequest request = MockServerHttpRequest
+			.get("http://host.docker.internal:8090/back-channel/logout")
+			.build();
+		String endpoint = logoutHandler.computeLogoutEndpoint(request);
+		assertThat(endpoint).isEqualTo("http://host.docker.internal:8090/logout");
+	}
+
+	// gh-14609
+	@Test
+	public void computeLogoutEndpointWhenLogoutUriThenUses() {
+		OidcBackChannelServerLogoutHandler logoutHandler = new OidcBackChannelServerLogoutHandler();
+		logoutHandler.setLogoutUri("http://localhost:8090/logout");
+		MockServerHttpRequest request = MockServerHttpRequest.get("https://server-one.com/back-channel/logout").build();
+		String endpoint = logoutHandler.computeLogoutEndpoint(request);
+		assertThat(endpoint).isEqualTo("http://localhost:8090/logout");
+	}
+
 }

+ 49 - 1
config/src/test/java/org/springframework/security/config/web/server/OidcLogoutSpecTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2021 the original author or authors.
+ * Copyright 2002-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -242,6 +242,32 @@ public class OidcLogoutSpecTests {
 		this.test.get().uri("/token/logout").cookie("SESSION", three).exchange().expectStatus().isUnauthorized();
 	}
 
+	@Test
+	void logoutWhenRemoteLogoutUriThenUses() {
+		this.spring.register(WebServerConfig.class, OidcProviderConfig.class, LogoutUriConfig.class).autowire();
+		String registrationId = this.clientRegistration.getRegistrationId();
+		String one = login();
+		String logoutToken = this.test.get()
+			.uri("/token/logout/all")
+			.cookie("SESSION", one)
+			.exchange()
+			.expectStatus()
+			.isOk()
+			.returnResult(String.class)
+			.getResponseBody()
+			.blockFirst();
+		this.test.post()
+			.uri(this.web.url("/logout/connect/back-channel/" + registrationId).toString())
+			.body(BodyInserters.fromFormData("logout_token", logoutToken))
+			.exchange()
+			.expectStatus()
+			.isBadRequest()
+			.expectBody(String.class)
+			.value(containsString("partial_logout"))
+			.value(containsString("Connection refused"));
+		this.test.get().uri("/token/logout").cookie("SESSION", one).exchange().expectStatus().isOk();
+	}
+
 	@Test
 	void logoutWhenRemoteLogoutFailsThenReportsPartialLogout() {
 		this.spring.register(WebServerConfig.class, OidcProviderConfig.class, WithBrokenLogoutConfig.class).autowire();
@@ -396,6 +422,28 @@ public class OidcLogoutSpecTests {
 
 	}
 
+	@Configuration
+	@EnableWebFluxSecurity
+	@Import(RegistrationConfig.class)
+	static class LogoutUriConfig {
+
+		@Bean
+		@Order(1)
+		SecurityWebFilterChain filters(ServerHttpSecurity http) throws Exception {
+			// @formatter:off
+			http
+					.authorizeExchange((authorize) -> authorize.anyExchange().authenticated())
+					.oauth2Login(Customizer.withDefaults())
+					.oidcLogout((oidc) -> oidc
+						.backChannel((backchannel) -> backchannel.logoutUri("http://localhost/wrong"))
+					);
+			// @formatter:on
+
+			return http.build();
+		}
+
+	}
+
 	@Configuration
 	@EnableWebFluxSecurity
 	@Import(RegistrationConfig.class)