Przeglądaj źródła

Default to server_error when OAuth2Error.errorCode is null

Fixes gh-5594
Joe Grandja 7 lat temu
rodzic
commit
e243f93eed

+ 10 - 2
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusAuthorizationCodeTokenResponseClient.java

@@ -37,6 +37,7 @@ import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
+import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.util.CollectionUtils;
 
@@ -111,8 +112,15 @@ public class NimbusAuthorizationCodeTokenResponseClient implements OAuth2AccessT
 		if (!tokenResponse.indicatesSuccess()) {
 			TokenErrorResponse tokenErrorResponse = (TokenErrorResponse) tokenResponse;
 			ErrorObject errorObject = tokenErrorResponse.getErrorObject();
-			OAuth2Error oauth2Error = new OAuth2Error(errorObject.getCode(), errorObject.getDescription(),
-				(errorObject.getURI() != null ? errorObject.getURI().toString() : null));
+			OAuth2Error oauth2Error;
+			if (errorObject == null) {
+				oauth2Error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR);
+			} else {
+				oauth2Error = new OAuth2Error(
+						errorObject.getCode() != null ? errorObject.getCode() : OAuth2ErrorCodes.SERVER_ERROR,
+						errorObject.getDescription(),
+						errorObject.getURI() != null ? errorObject.getURI().toString() : null);
+			}
 			throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
 		}
 

+ 22 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusAuthorizationCodeTokenResponseClientTests.java

@@ -214,6 +214,28 @@ public class NimbusAuthorizationCodeTokenResponseClientTests {
 		}
 	}
 
+	// gh-5594
+	@Test
+	public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthenticationException() throws Exception {
+		this.exception.expect(OAuth2AuthenticationException.class);
+		this.exception.expectMessage(containsString("server_error"));
+
+		MockWebServer server = new MockWebServer();
+
+		server.enqueue(new MockResponse().setResponseCode(500));
+		server.start();
+
+		String tokenUri = server.url("/oauth2/token").toString();
+		when(this.providerDetails.getTokenUri()).thenReturn(tokenUri);
+
+		try {
+			this.tokenResponseClient.getTokenResponse(
+					new OAuth2AuthorizationCodeGrantRequest(this.clientRegistration, this.authorizationExchange));
+		} finally {
+			server.shutdown();
+		}
+	}
+
 	@Test
 	public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthenticationException() throws Exception {
 		this.exception.expect(OAuth2AuthenticationException.class);

+ 11 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClientTests.java

@@ -187,6 +187,17 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
 			.hasMessageContaining("unauthorized_client");
 	}
 
+	// gh-5594
+	@Test
+	public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthenticationException() throws Exception {
+		String accessTokenErrorResponse = "{}";
+		this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(HttpStatus.INTERNAL_SERVER_ERROR.value()));
+
+		assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest()).block())
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.hasMessageContaining("server_error");
+	}
+
 	@Test
 	public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthenticationException() throws Exception {
 		String accessTokenSuccessResponse = "{\n" +

+ 10 - 5
oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/web/reactive/function/OAuth2AccessTokenResponseBodyExtractor.java

@@ -28,6 +28,7 @@ import org.springframework.http.ReactiveHttpInputMessage;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
+import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.web.reactive.function.BodyExtractor;
 import org.springframework.web.reactive.function.BodyExtractors;
@@ -80,11 +81,15 @@ class OAuth2AccessTokenResponseBodyExtractor
 		}
 		TokenErrorResponse tokenErrorResponse = (TokenErrorResponse) tokenResponse;
 		ErrorObject errorObject = tokenErrorResponse.getErrorObject();
-		OAuth2Error oauth2Error = new OAuth2Error(errorObject.getCode(),
-				errorObject.getDescription(), (errorObject.getURI() != null ?
-				errorObject.getURI().toString() :
-				null));
-
+		OAuth2Error oauth2Error;
+		if (errorObject == null) {
+			oauth2Error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR);
+		} else {
+			oauth2Error = new OAuth2Error(
+					errorObject.getCode() != null ? errorObject.getCode() : OAuth2ErrorCodes.SERVER_ERROR,
+					errorObject.getDescription(),
+					errorObject.getURI() != null ? errorObject.getURI().toString() : null);
+		}
 		return Mono.error(new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()));
 	}