Browse Source

WebFlux oauth2Login() redirects on failed authentication

Fixes gh-5562 gh-6484
Joe Grandja 5 years ago
parent
commit
459e8f1a11

+ 3 - 13
config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2019 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -53,7 +53,6 @@ import org.springframework.security.authorization.AuthorityReactiveAuthorization
 import org.springframework.security.authorization.AuthorizationDecision;
 import org.springframework.security.authorization.ReactiveAuthorizationManager;
 import org.springframework.security.core.Authentication;
-import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.oauth2.client.InMemoryReactiveOAuth2AuthorizedClientService;
 import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService;
 import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeReactiveAuthenticationManager;
@@ -89,7 +88,6 @@ import org.springframework.security.web.server.DelegatingServerAuthenticationEnt
 import org.springframework.security.web.server.MatcherSecurityWebFilterChain;
 import org.springframework.security.web.server.SecurityWebFilterChain;
 import org.springframework.security.web.server.ServerAuthenticationEntryPoint;
-import org.springframework.security.web.server.WebFilterExchange;
 import org.springframework.security.web.server.authentication.AuthenticationWebFilter;
 import org.springframework.security.web.server.authentication.HttpBasicServerAuthenticationEntryPoint;
 import org.springframework.security.web.server.authentication.RedirectServerAuthenticationEntryPoint;
@@ -619,16 +617,8 @@ public class ServerHttpSecurity {
 			AuthenticationWebFilter authenticationFilter = new OAuth2LoginAuthenticationWebFilter(manager, authorizedClientRepository);
 			authenticationFilter.setRequiresAuthenticationMatcher(createAttemptAuthenticationRequestMatcher());
 			authenticationFilter.setServerAuthenticationConverter(getAuthenticationConverter(clientRegistrationRepository));
-			RedirectServerAuthenticationSuccessHandler redirectHandler = new RedirectServerAuthenticationSuccessHandler();
-
-			authenticationFilter.setAuthenticationSuccessHandler(redirectHandler);
-			authenticationFilter.setAuthenticationFailureHandler(new ServerAuthenticationFailureHandler() {
-				@Override
-				public Mono<Void> onAuthenticationFailure(WebFilterExchange webFilterExchange,
-						AuthenticationException exception) {
-					return Mono.error(exception);
-				}
-			});
+			authenticationFilter.setAuthenticationSuccessHandler(new RedirectServerAuthenticationSuccessHandler());
+			authenticationFilter.setAuthenticationFailureHandler(new RedirectServerAuthenticationFailureHandler("/login?error"));
 			authenticationFilter.setSecurityContextRepository(new WebSessionServerSecurityContextRepository());
 
 			MediaTypeServerWebExchangeMatcher htmlMatcher = new MediaTypeServerWebExchangeMatcher(

+ 95 - 6
config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2019 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -34,13 +34,26 @@ import org.springframework.security.config.annotation.web.reactive.EnableWebFlux
 import org.springframework.security.config.oauth2.client.CommonOAuth2Provider;
 import org.springframework.security.config.test.SpringTestRule;
 import org.springframework.security.htmlunit.server.WebTestClientHtmlUnitDriverBuilder;
+import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken;
 import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken;
+import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
+import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
+import org.springframework.security.oauth2.client.oidc.authentication.OidcAuthorizationCodeReactiveAuthenticationManager;
+import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.InMemoryReactiveClientRegistrationRepository;
+import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
 import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationExchanges;
+import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests;
+import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses;
+import org.springframework.security.oauth2.core.oidc.user.OidcUser;
 import org.springframework.security.oauth2.core.user.OAuth2User;
 import org.springframework.security.oauth2.core.user.TestOAuth2Users;
 import org.springframework.security.test.web.reactive.server.WebTestClientBuilder;
@@ -54,6 +67,9 @@ import org.springframework.web.server.WebFilterChain;
 
 import reactor.core.publisher.Mono;
 
+import java.time.Duration;
+import java.time.Instant;
+
 /**
  * @author Rob Winch
  * @since 5.1
@@ -72,6 +88,12 @@ public class OAuth2LoginTests {
 			.clientSecret("secret")
 			.build();
 
+	private static ClientRegistration google = CommonOAuth2Provider.GOOGLE
+			.getBuilder("google")
+			.clientId("client")
+			.clientSecret("secret")
+			.build();
+
 	@Test
 	public void defaultLoginPageWithMultipleClientRegistrationsThenLinks() {
 		this.spring.register(OAuth2LoginWithMulitpleClientRegistrations.class).autowire();
@@ -97,11 +119,6 @@ public class OAuth2LoginTests {
 	static class OAuth2LoginWithMulitpleClientRegistrations {
 		@Bean
 		InMemoryReactiveClientRegistrationRepository clientRegistrationRepository() {
-			ClientRegistration google = CommonOAuth2Provider.GOOGLE
-					.getBuilder("google")
-					.clientId("client")
-					.clientSecret("secret")
-					.build();
 			return new InMemoryReactiveClientRegistrationRepository(github, google);
 		}
 	}
@@ -182,6 +199,78 @@ public class OAuth2LoginTests {
 		}
 	}
 
+	// gh-5562
+	@Test
+	public void oauth2LoginWhenAccessTokenRequestFailsThenDefaultRedirectToLogin() {
+		this.spring.register(OAuth2LoginWithMulitpleClientRegistrations.class,
+				OAuth2LoginWithCustomBeansConfig.class).autowire();
+
+		WebTestClient webTestClient = WebTestClientBuilder
+				.bindToWebFilters(this.springSecurity)
+				.build();
+
+		OAuth2AuthorizationRequest request = TestOAuth2AuthorizationRequests.request().scope("openid").build();
+		OAuth2AuthorizationResponse response = TestOAuth2AuthorizationResponses.success().build();
+		OAuth2AuthorizationExchange exchange = new OAuth2AuthorizationExchange(request, response);
+		OAuth2AccessToken accessToken =  new OAuth2AccessToken(
+				OAuth2AccessToken.TokenType.BEARER, "openid", Instant.now(), Instant.now().plus(Duration.ofDays(1)));
+		OAuth2AuthorizationCodeAuthenticationToken authenticationToken =
+				new OAuth2AuthorizationCodeAuthenticationToken(google, exchange, accessToken);
+
+		OAuth2LoginWithCustomBeansConfig config = this.spring.getContext().getBean(OAuth2LoginWithCustomBeansConfig.class);
+
+		ServerAuthenticationConverter converter = config.authenticationConverter;
+		when(converter.convert(any())).thenReturn(Mono.just(authenticationToken));
+
+		ReactiveOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> tokenResponseClient = config.tokenResponseClient;
+		OAuth2Error oauth2Error = new OAuth2Error("invalid_request", "Invalid request", null);
+		when(tokenResponseClient.getTokenResponse(any())).thenThrow(new OAuth2AuthenticationException(oauth2Error));
+
+		webTestClient.get()
+				.uri("/login/oauth2/code/google")
+				.exchange()
+				.expectStatus()
+					.is3xxRedirection()
+				.expectHeader()
+					.valueEquals("Location", "/login?error");
+	}
+
+	@Configuration
+	static class OAuth2LoginWithCustomBeansConfig {
+
+		ServerAuthenticationConverter authenticationConverter = mock(ServerAuthenticationConverter.class);
+
+		ReactiveOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> tokenResponseClient =
+				mock(ReactiveOAuth2AccessTokenResponseClient.class);
+
+		ReactiveOAuth2UserService<OidcUserRequest, OidcUser> userService = mock(ReactiveOAuth2UserService.class);
+
+		@Bean
+		public SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) {
+			// @formatter:off
+			http
+				.authorizeExchange()
+					.anyExchange().authenticated()
+					.and()
+				.oauth2Login()
+					.authenticationConverter(authenticationConverter)
+					.authenticationManager(authenticationManager());
+			return http.build();
+			// @formatter:on
+		}
+
+		private ReactiveAuthenticationManager authenticationManager() {
+			OidcAuthorizationCodeReactiveAuthenticationManager oidc =
+					new OidcAuthorizationCodeReactiveAuthenticationManager(tokenResponseClient, userService);
+			return oidc;
+		}
+
+		@Bean
+		public ReactiveOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient() {
+			return tokenResponseClient;
+		}
+	}
+
 	static class GitHubWebFilter implements WebFilter {
 
 		@Override