Browse Source

OAuth2LoginAuthenticationWebFilter should handle OAuth2AuthorizationException

Issue gh-8609
Joe Grandja 5 years ago
parent
commit
674e2c0a8e

+ 8 - 2
config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java

@@ -33,6 +33,8 @@ import java.util.function.Supplier;
 
 import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository;
 import org.springframework.security.oauth2.client.web.server.WebSessionOAuth2ServerAuthorizationRequestRepository;
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
 import reactor.core.publisher.Mono;
 import reactor.util.context.Context;
 
@@ -1089,8 +1091,12 @@ public class ServerHttpSecurity {
 
 		private ServerAuthenticationConverter getAuthenticationConverter(ReactiveClientRegistrationRepository clientRegistrationRepository) {
 			if (this.authenticationConverter == null) {
-				ServerOAuth2AuthorizationCodeAuthenticationTokenConverter authenticationConverter = new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter(clientRegistrationRepository);
-				authenticationConverter.setAuthorizationRequestRepository(getAuthorizationRequestRepository());
+				ServerOAuth2AuthorizationCodeAuthenticationTokenConverter delegate =
+						new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter(clientRegistrationRepository);
+				delegate.setAuthorizationRequestRepository(getAuthorizationRequestRepository());
+				ServerAuthenticationConverter authenticationConverter = exchange ->
+						delegate.convert(exchange).onErrorMap(OAuth2AuthorizationException.class,
+								e -> new OAuth2AuthenticationException(e.getError(), e.getError().toString()));
 				this.authenticationConverter = authenticationConverter;
 			}
 			return this.authenticationConverter;

+ 22 - 2
config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java

@@ -103,7 +103,10 @@ import java.util.Map;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.Mockito.*;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 import static org.springframework.security.oauth2.jwt.TestJwts.jwt;
 
 /**
@@ -683,7 +686,6 @@ public class OAuth2LoginTests {
 		}
 	}
 
-
 	@Test
 	public void logoutWhenUsingOidcLogoutHandlerThenRedirects() {
 		this.spring.register(OAuth2LoginConfigWithOidcLogoutSuccessHandler.class).autowire();
@@ -739,6 +741,24 @@ public class OAuth2LoginTests {
 		}
 	}
 
+	// gh-8609
+	@Test
+	public void oauth2LoginWhenAuthenticationConverterFailsThenDefaultRedirectToLogin() {
+		this.spring.register(OAuth2LoginWithMultipleClientRegistrations.class).autowire();
+
+		WebTestClient webTestClient = WebTestClientBuilder
+				.bindToWebFilters(this.springSecurity)
+				.build();
+
+		webTestClient.get()
+				.uri("/login/oauth2/code/google")
+				.exchange()
+				.expectStatus()
+				.is3xxRedirection()
+				.expectHeader()
+				.valueEquals("Location", "/login?error");
+	}
+
 	static class GitHubWebFilter implements WebFilter {
 
 		@Override

+ 7 - 6
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManager.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -121,13 +121,14 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements
 					.getAuthorizationExchange().getAuthorizationResponse();
 
 			if (authorizationResponse.statusError()) {
-				throw new OAuth2AuthenticationException(
-						authorizationResponse.getError(), authorizationResponse.getError().toString());
+				return Mono.error(new OAuth2AuthenticationException(
+						authorizationResponse.getError(), authorizationResponse.getError().toString()));
 			}
 
 			if (!authorizationResponse.getState().equals(authorizationRequest.getState())) {
 				OAuth2Error oauth2Error = new OAuth2Error(INVALID_STATE_PARAMETER_ERROR_CODE);
-				throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
+				return Mono.error(new OAuth2AuthenticationException(
+						oauth2Error, oauth2Error.toString()));
 			}
 
 			OAuth2AuthorizationCodeGrantRequest authzRequest = new OAuth2AuthorizationCodeGrantRequest(
@@ -139,7 +140,7 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements
 					.onErrorMap(OAuth2AuthorizationException.class, e -> new OAuth2AuthenticationException(e.getError(), e.getError().toString()))
 					.onErrorMap(JwtException.class, e -> {
 						OAuth2Error invalidIdTokenError = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, e.getMessage(), null);
-						throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString(), e);
+						return new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString(), e);
 					});
 		});
 	}
@@ -166,7 +167,7 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements
 					INVALID_ID_TOKEN_ERROR_CODE,
 					"Missing (required) ID Token in Token Response for Client Registration: " + clientRegistration.getRegistrationId(),
 					null);
-			throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString());
+			return Mono.error(new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString()));
 		}
 
 		return createOidcToken(clientRegistration, accessTokenResponse)

+ 6 - 9
web/src/main/java/org/springframework/security/web/server/authentication/AuthenticationWebFilter.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -106,19 +106,16 @@ public class AuthenticationWebFilter implements WebFilter {
 			.filter( matchResult -> matchResult.isMatch())
 			.flatMap( matchResult -> this.authenticationConverter.convert(exchange))
 			.switchIfEmpty(chain.filter(exchange).then(Mono.empty()))
-			.flatMap( token -> authenticate(exchange, chain, token));
+			.flatMap( token -> authenticate(exchange, chain, token))
+			.onErrorResume(AuthenticationException.class, e -> this.authenticationFailureHandler
+					.onAuthenticationFailure(new WebFilterExchange(exchange, chain), e));
 	}
 
-	private Mono<Void> authenticate(ServerWebExchange exchange,
-		WebFilterChain chain, Authentication token) {
-		WebFilterExchange webFilterExchange = new WebFilterExchange(exchange, chain);
-
+	private Mono<Void> authenticate(ServerWebExchange exchange, WebFilterChain chain, Authentication token) {
 		return this.authenticationManagerResolver.resolve(exchange.getRequest())
 			.flatMap(authenticationManager -> authenticationManager.authenticate(token))
 			.switchIfEmpty(Mono.defer(() -> Mono.error(new IllegalStateException("No provider found for " + token.getClass()))))
-			.flatMap(authentication -> onAuthenticationSuccess(authentication, webFilterExchange))
-			.onErrorResume(AuthenticationException.class, e -> this.authenticationFailureHandler
-				.onAuthenticationFailure(webFilterExchange, e));
+			.flatMap(authentication -> onAuthenticationSuccess(authentication, new WebFilterExchange(exchange, chain)));
 	}
 
 	protected Mono<Void> onAuthenticationSuccess(Authentication authentication, WebFilterExchange webFilterExchange) {