ソースを参照

Reactive Oidc RP-Initiated Logout

Issue: gh-5350
Josh Cummings 6 年 前
コミット
fba31dfb6a

+ 96 - 7
config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java

@@ -16,10 +16,18 @@
 
 package org.springframework.security.config.web.server;
 
+import java.time.Instant;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
 import org.junit.Rule;
 import org.junit.Test;
 import org.openqa.selenium.WebDriver;
+import reactor.core.publisher.Mono;
+
 import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.context.ApplicationContext;
 import org.springframework.context.annotation.Bean;
 import org.springframework.context.annotation.Configuration;
 import org.springframework.security.authentication.ReactiveAuthenticationManager;
@@ -27,15 +35,22 @@ import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity;
 import org.springframework.security.config.oauth2.client.CommonOAuth2Provider;
 import org.springframework.security.config.test.SpringTestRule;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.authority.AuthorityUtils;
+import org.springframework.security.core.context.SecurityContext;
+import org.springframework.security.core.context.SecurityContextImpl;
 import org.springframework.security.htmlunit.server.WebTestClientHtmlUnitDriverBuilder;
+import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
 import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken;
 import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken;
 import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
 import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
 import org.springframework.security.oauth2.client.oidc.authentication.OidcAuthorizationCodeReactiveAuthenticationManager;
+import org.springframework.security.oauth2.client.web.server.oidc.logout.OidcClientInitiatedServerLogoutSuccessHandler;
 import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.InMemoryReactiveClientRegistrationRepository;
+import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
 import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService;
 import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationRequestResolver;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
@@ -60,21 +75,21 @@ import org.springframework.security.test.web.reactive.server.WebTestClientBuilde
 import org.springframework.security.web.server.SecurityWebFilterChain;
 import org.springframework.security.web.server.WebFilterChainProxy;
 import org.springframework.security.web.server.authentication.ServerAuthenticationConverter;
+import org.springframework.security.web.server.context.ServerSecurityContextRepository;
 import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
 import org.springframework.test.web.reactive.server.WebTestClient;
+import org.springframework.web.reactive.config.EnableWebFlux;
 import org.springframework.web.server.ServerWebExchange;
 import org.springframework.web.server.WebFilter;
 import org.springframework.web.server.WebFilterChain;
-import reactor.core.publisher.Mono;
-
-import java.time.Instant;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.Map;
+import org.springframework.web.server.WebHandler;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.Mockito.*;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 
 /**
  * @author Rob Winch
@@ -85,6 +100,8 @@ public class OAuth2LoginTests {
 	@Rule
 	public final SpringTestRule spring = new SpringTestRule();
 
+	private WebTestClient client;
+
 	@Autowired
 	private WebFilterChainProxy springSecurity;
 
@@ -100,6 +117,14 @@ public class OAuth2LoginTests {
 			.clientSecret("secret")
 			.build();
 
+	@Autowired
+	public void setApplicationContext(ApplicationContext context) {
+		if (context.getBeanNamesForType(WebHandler.class).length > 0) {
+			this.client = WebTestClient.bindToApplicationContext(context)
+					.build();
+		}
+	}
+
 	@Test
 	public void defaultLoginPageWithMultipleClientRegistrationsThenLinks() {
 		this.spring.register(OAuth2LoginWithMultipleClientRegistrations.class).autowire();
@@ -326,6 +351,60 @@ public class OAuth2LoginTests {
 		}
 	}
 
+
+	@Test
+	public void logoutWhenUsingOidcLogoutHandlerThenRedirects() throws Exception {
+		this.spring.register(OAuth2LoginConfigWithOidcLogoutSuccessHandler.class).autowire();
+
+		OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(
+				TestOidcUsers.create(),
+				AuthorityUtils.NO_AUTHORITIES,
+				getBean(ClientRegistration.class).getRegistrationId());
+
+		ServerSecurityContextRepository repository = getBean(ServerSecurityContextRepository.class);
+		when(repository.load(any())).thenReturn(authentication(token));
+
+		this.client.post().uri("/logout")
+				.exchange()
+				.expectHeader().valueEquals("Location", "http://logout?id_token_hint=id-token");
+	}
+
+	@EnableWebFlux
+	@EnableWebFluxSecurity
+	static class OAuth2LoginConfigWithOidcLogoutSuccessHandler {
+		private final ServerSecurityContextRepository repository =
+				mock(ServerSecurityContextRepository.class);
+		private final ClientRegistration withLogout =
+				TestClientRegistrations.clientRegistration()
+						.providerConfigurationMetadata(Collections.singletonMap(
+								"end_session_endpoint", "http://logout")).build();
+
+		@Bean
+		public SecurityWebFilterChain springSecurity(ServerHttpSecurity http) {
+
+			http
+				.csrf().disable()
+				.logout()
+					.logoutSuccessHandler(
+							new OidcClientInitiatedServerLogoutSuccessHandler(
+									new InMemoryReactiveClientRegistrationRepository(this.withLogout)))
+					.and()
+				.securityContextRepository(this.repository);
+
+			return http.build();
+		}
+
+		@Bean
+		ServerSecurityContextRepository securityContextRepository() {
+			return this.repository;
+		}
+
+		@Bean
+		ClientRegistration clientRegistration() {
+			return this.withLogout;
+		}
+	}
+
 	static class GitHubWebFilter implements WebFilter {
 
 		@Override
@@ -336,4 +415,14 @@ public class OAuth2LoginTests {
 			return chain.filter(exchange);
 		}
 	}
+
+	Mono<SecurityContext> authentication(Authentication authentication) {
+		SecurityContext context = new SecurityContextImpl();
+		context.setAuthentication(authentication);
+		return Mono.just(context);
+	}
+
+	<T> T getBean(Class<T> beanClass) {
+		return this.spring.getContext().getBean(beanClass);
+	}
 }

+ 126 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/oidc/logout/OidcClientInitiatedServerLogoutSuccessHandler.java

@@ -0,0 +1,126 @@
+/*
+ * Copyright 2002-2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.client.web.server.oidc.logout;
+
+import java.net.URI;
+import java.nio.charset.StandardCharsets;
+
+import reactor.core.publisher.Mono;
+
+import org.springframework.security.core.Authentication;
+import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
+import org.springframework.security.oauth2.core.oidc.user.OidcUser;
+import org.springframework.security.web.server.DefaultServerRedirectStrategy;
+import org.springframework.security.web.server.ServerRedirectStrategy;
+import org.springframework.security.web.server.WebFilterExchange;
+import org.springframework.security.web.server.authentication.logout.RedirectServerLogoutSuccessHandler;
+import org.springframework.security.web.server.authentication.logout.ServerLogoutSuccessHandler;
+import org.springframework.util.Assert;
+import org.springframework.web.util.UriComponentsBuilder;
+
+/**
+ * A reactive logout success handler for initiating OIDC logout through the user agent.
+ *
+ * @author Josh Cummings
+ * @since 5.2
+ * @see <a href="http://openid.net/specs/openid-connect-session-1_0.html#RPLogout">RP-Initiated Logout</a>
+ * @see org.springframework.security.web.server.authentication.logout.ServerLogoutSuccessHandler
+ */
+public class OidcClientInitiatedServerLogoutSuccessHandler
+		implements ServerLogoutSuccessHandler {
+
+	private final ServerRedirectStrategy redirectStrategy = new DefaultServerRedirectStrategy();
+	private final RedirectServerLogoutSuccessHandler serverLogoutSuccessHandler
+			= new RedirectServerLogoutSuccessHandler();
+	private final ReactiveClientRegistrationRepository clientRegistrationRepository;
+
+	private URI postLogoutRedirectUri;
+
+	/**
+	 * Constructs an {@link OidcClientInitiatedServerLogoutSuccessHandler} with the provided parameters
+	 *
+	 * @param clientRegistrationRepository The {@link ReactiveClientRegistrationRepository} to use to derive
+	 * the end_session_endpoint value
+	 */
+	public OidcClientInitiatedServerLogoutSuccessHandler
+			(ReactiveClientRegistrationRepository clientRegistrationRepository) {
+
+		Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
+		this.clientRegistrationRepository = clientRegistrationRepository;
+	}
+
+	/**
+	 * {@inheritDoc}
+	 */
+	@Override
+	public Mono<Void> onLogoutSuccess(WebFilterExchange exchange, Authentication authentication) {
+		return Mono.just(authentication)
+				.filter(OAuth2AuthenticationToken.class::isInstance)
+				.filter(token -> authentication.getPrincipal() instanceof OidcUser)
+				.map(OAuth2AuthenticationToken.class::cast)
+				.flatMap(this::endSessionEndpoint)
+				.map(endSessionEndpoint -> endpointUri(endSessionEndpoint, authentication))
+				.switchIfEmpty(this.serverLogoutSuccessHandler
+						.onLogoutSuccess(exchange, authentication).then(Mono.empty()))
+				.flatMap(endpointUri -> this.redirectStrategy.sendRedirect(exchange.getExchange(), endpointUri));
+	}
+
+	private Mono<URI> endSessionEndpoint(OAuth2AuthenticationToken token) {
+		String registrationId = token.getAuthorizedClientRegistrationId();
+		return this.clientRegistrationRepository.findByRegistrationId(registrationId)
+				.map(ClientRegistration::getProviderDetails)
+				.map(ClientRegistration.ProviderDetails::getConfigurationMetadata)
+				.flatMap(configurationMetadata -> Mono.justOrEmpty(configurationMetadata.get("end_session_endpoint")))
+				.map(Object::toString)
+				.map(URI::create);
+	}
+
+	private URI endpointUri(URI endSessionEndpoint, Authentication authentication) {
+		UriComponentsBuilder builder = UriComponentsBuilder.fromUri(endSessionEndpoint);
+		builder.queryParam("id_token_hint", idToken(authentication));
+		if (this.postLogoutRedirectUri != null) {
+			builder.queryParam("post_logout_redirect_uri", this.postLogoutRedirectUri);
+		}
+		return builder.encode(StandardCharsets.UTF_8).build().toUri();
+	}
+
+	private String idToken(Authentication authentication) {
+		return ((OidcUser) authentication.getPrincipal()).getIdToken().getTokenValue();
+	}
+
+	/**
+	 * Set the post logout redirect uri to use
+	 *
+	 * @param postLogoutRedirectUri - A valid URL to which the OP should redirect after logging out the user
+	 */
+	public void setPostLogoutRedirectUri(URI postLogoutRedirectUri) {
+		Assert.notNull(postLogoutRedirectUri, "postLogoutRedirectUri cannot be empty");
+		this.postLogoutRedirectUri = postLogoutRedirectUri;
+	}
+
+	/**
+	 * The URL to redirect to after successfully logging out when not originally an OIDC login
+	 *
+	 * @param logoutSuccessUrl the url to redirect to. Default is "/login?logout".
+	 */
+	public void setLogoutSuccessUrl(URI logoutSuccessUrl) {
+		Assert.notNull(logoutSuccessUrl, "logoutSuccessUrl cannot be null");
+		this.serverLogoutSuccessHandler.setLogoutSuccessUrl(logoutSuccessUrl);
+	}
+}

+ 163 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/oidc/logout/OidcClientInitiatedServerLogoutSuccessHandlerTests.java

@@ -0,0 +1,163 @@
+/*
+ * Copyright 2002-2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.client.web.server.oidc.logout;
+
+import java.net.URI;
+import java.util.Collections;
+
+import org.junit.Before;
+import org.junit.Test;
+import reactor.core.publisher.Mono;
+
+import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
+import org.springframework.mock.http.server.reactive.MockServerHttpResponse;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.authority.AuthorityUtils;
+import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.client.registration.InMemoryReactiveClientRegistrationRepository;
+import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
+import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
+import org.springframework.security.oauth2.core.oidc.user.TestOidcUsers;
+import org.springframework.security.oauth2.core.user.TestOAuth2Users;
+import org.springframework.security.web.server.WebFilterExchange;
+import org.springframework.web.server.ServerWebExchange;
+import org.springframework.web.server.WebFilterChain;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class OidcClientInitiatedServerLogoutSuccessHandlerTests {
+	ClientRegistration registration = TestClientRegistrations
+			.clientRegistration()
+			.providerConfigurationMetadata(
+					Collections.singletonMap("end_session_endpoint", "http://endpoint"))
+			.build();
+	ReactiveClientRegistrationRepository repository = new InMemoryReactiveClientRegistrationRepository(registration);
+
+	ServerWebExchange exchange;
+	WebFilterChain chain;
+
+	OidcClientInitiatedServerLogoutSuccessHandler handler;
+
+	@Before
+	public void setup() {
+		this.exchange = mock(ServerWebExchange.class);
+		when(this.exchange.getResponse()).thenReturn(new MockServerHttpResponse());
+		when(this.exchange.getRequest()).thenReturn(MockServerHttpRequest.get("/").build());
+		this.chain = mock(WebFilterChain.class);
+		this.handler = new OidcClientInitiatedServerLogoutSuccessHandler(this.repository);
+	}
+
+	@Test
+	public void logoutWhenOidcRedirectUrlConfiguredThenRedirects() {
+		OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(
+				TestOidcUsers.create(),
+				AuthorityUtils.NO_AUTHORITIES,
+				this.registration.getRegistrationId());
+
+		when(this.exchange.getPrincipal()).thenReturn(Mono.just(token));
+		WebFilterExchange f = new WebFilterExchange(exchange, this.chain);
+		this.handler.onLogoutSuccess(f, token).block();
+
+		assertThat(redirectedUrl(this.exchange)).isEqualTo("http://endpoint?id_token_hint=id-token");
+	}
+
+	@Test
+	public void logoutWhenNotOAuth2AuthenticationThenDefaults() {
+		Authentication token = mock(Authentication.class);
+
+		when(this.exchange.getPrincipal()).thenReturn(Mono.just(token));
+		WebFilterExchange f = new WebFilterExchange(exchange, this.chain);
+
+		this.handler.setLogoutSuccessUrl(URI.create("http://default"));
+		this.handler.onLogoutSuccess(f, token).block();
+
+		assertThat(redirectedUrl(this.exchange)).isEqualTo("http://default");
+	}
+
+	@Test
+	public void logoutWhenNotOidcUserThenDefaults() {
+		OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(
+				TestOAuth2Users.create(),
+				AuthorityUtils.NO_AUTHORITIES,
+				this.registration.getRegistrationId());
+
+		when(this.exchange.getPrincipal()).thenReturn(Mono.just(token));
+		WebFilterExchange f = new WebFilterExchange(exchange, this.chain);
+
+		this.handler.setLogoutSuccessUrl(URI.create("http://default"));
+		this.handler.onLogoutSuccess(f, token).block();
+
+		assertThat(redirectedUrl(this.exchange)).isEqualTo("http://default");
+	}
+
+	@Test
+	public void logoutWhenClientRegistrationHasNoEndSessionEndpointThenDefaults() {
+
+		ClientRegistration registration = TestClientRegistrations.clientRegistration().build();
+		ReactiveClientRegistrationRepository repository =
+				new InMemoryReactiveClientRegistrationRepository(registration);
+		OidcClientInitiatedServerLogoutSuccessHandler handler =
+				new OidcClientInitiatedServerLogoutSuccessHandler(repository);
+
+		OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(
+				TestOidcUsers.create(),
+				AuthorityUtils.NO_AUTHORITIES,
+				registration.getRegistrationId());
+
+		when(this.exchange.getPrincipal()).thenReturn(Mono.just(token));
+		WebFilterExchange f = new WebFilterExchange(exchange, this.chain);
+
+		handler.setLogoutSuccessUrl(URI.create("http://default"));
+		handler.onLogoutSuccess(f, token).block();
+
+		assertThat(redirectedUrl(this.exchange)).isEqualTo("http://default");
+	}
+
+	@Test
+	public void logoutWhenUsingPostLogoutRedirectUriThenIncludesItInRedirect() {
+
+		OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(
+				TestOidcUsers.create(),
+				AuthorityUtils.NO_AUTHORITIES,
+				this.registration.getRegistrationId());
+
+		when(this.exchange.getPrincipal()).thenReturn(Mono.just(token));
+		WebFilterExchange f = new WebFilterExchange(exchange, this.chain);
+
+		this.handler.setPostLogoutRedirectUri(URI.create("http://postlogout?encodedparam=value"));
+		this.handler.onLogoutSuccess(f, token).block();
+
+		assertThat(redirectedUrl(this.exchange))
+				.isEqualTo("http://endpoint?" +
+						"id_token_hint=id-token&" +
+						"post_logout_redirect_uri=http://postlogout?encodedparam%3Dvalue");
+	}
+
+	@Test
+	public void setPostLogoutRedirectUriWhenGivenNullThenThrowsException() {
+		assertThatThrownBy(() -> this.handler.setPostLogoutRedirectUri(null))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
+
+	private String redirectedUrl(ServerWebExchange exchange) {
+		return exchange.getResponse().getHeaders().getFirst("Location");
+	}
+}