浏览代码

Improve OAuth2LoginAuthenticationProvider

1. update OAuth2LoginAuthenticationProvider to use
OAuth2AuthorizationCodeAuthenticationProvider
2. apply fix gh-5368 for OAuth2AuthorizationCodeAuthenticationProvider
to return additionalParameters value from accessTokenResponse

Fixes gh-5633
Ruby Hartono 5 年之前
父节点
当前提交
71b4248fe6

+ 3 - 2
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -73,7 +73,8 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica
 				authorizationCodeAuthentication.getClientRegistration(),
 				authorizationCodeAuthentication.getAuthorizationExchange(),
 				accessTokenResponse.getAccessToken(),
-				accessTokenResponse.getRefreshToken());
+				accessTokenResponse.getRefreshToken(),
+				accessTokenResponse.getAdditionalParameters());
 		authenticationResult.setDetails(authorizationCodeAuthentication.getDetails());
 
 		return authenticationResult;

+ 17 - 23
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProvider.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -28,7 +28,6 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
-import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.security.oauth2.core.user.OAuth2User;
 import org.springframework.util.Assert;
 
@@ -60,7 +59,7 @@ import java.util.Map;
  * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.4">Section 4.1.4 Access Token Response</a>
  */
 public class OAuth2LoginAuthenticationProvider implements AuthenticationProvider {
-	private final OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient;
+	private final OAuth2AuthorizationCodeAuthenticationProvider authorizationCodeAuthenticationProvider;
 	private final OAuth2UserService<OAuth2UserRequest, OAuth2User> userService;
 	private GrantedAuthoritiesMapper authoritiesMapper = (authorities -> authorities);
 
@@ -74,59 +73,54 @@ public class OAuth2LoginAuthenticationProvider implements AuthenticationProvider
 		OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient,
 		OAuth2UserService<OAuth2UserRequest, OAuth2User> userService) {
 
-		Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null");
 		Assert.notNull(userService, "userService cannot be null");
-		this.accessTokenResponseClient = accessTokenResponseClient;
+		this.authorizationCodeAuthenticationProvider = new OAuth2AuthorizationCodeAuthenticationProvider(accessTokenResponseClient);
 		this.userService = userService;
 	}
 
 	@Override
 	public Authentication authenticate(Authentication authentication) throws AuthenticationException {
-		OAuth2LoginAuthenticationToken authorizationCodeAuthentication =
+		OAuth2LoginAuthenticationToken loginAuthenticationToken =
 			(OAuth2LoginAuthenticationToken) authentication;
 
 		// Section 3.1.2.1 Authentication Request - https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
 		// scope
 		// 		REQUIRED. OpenID Connect requests MUST contain the "openid" scope value.
-		if (authorizationCodeAuthentication.getAuthorizationExchange()
+		if (loginAuthenticationToken.getAuthorizationExchange()
 			.getAuthorizationRequest().getScopes().contains("openid")) {
 			// This is an OpenID Connect Authentication Request so return null
 			// and let OidcAuthorizationCodeAuthenticationProvider handle it instead
 			return null;
 		}
 
-		OAuth2AccessTokenResponse accessTokenResponse;
+		OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthenticationToken;
 		try {
-			OAuth2AuthorizationExchangeValidator.validate(
-					authorizationCodeAuthentication.getAuthorizationExchange());
-
-			accessTokenResponse = this.accessTokenResponseClient.getTokenResponse(
-					new OAuth2AuthorizationCodeGrantRequest(
-							authorizationCodeAuthentication.getClientRegistration(),
-							authorizationCodeAuthentication.getAuthorizationExchange()));
-
+			authorizationCodeAuthenticationToken = (OAuth2AuthorizationCodeAuthenticationToken) this.authorizationCodeAuthenticationProvider
+					.authenticate(new OAuth2AuthorizationCodeAuthenticationToken(
+							loginAuthenticationToken.getClientRegistration(),
+							loginAuthenticationToken.getAuthorizationExchange()));
 		} catch (OAuth2AuthorizationException ex) {
 			OAuth2Error oauth2Error = ex.getError();
 			throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
 		}
 
-		OAuth2AccessToken accessToken = accessTokenResponse.getAccessToken();
-		Map<String, Object> additionalParameters = accessTokenResponse.getAdditionalParameters();
+		OAuth2AccessToken accessToken = authorizationCodeAuthenticationToken.getAccessToken();
+		Map<String, Object> additionalParameters = authorizationCodeAuthenticationToken.getAdditionalParameters();
 
 		OAuth2User oauth2User = this.userService.loadUser(new OAuth2UserRequest(
-				authorizationCodeAuthentication.getClientRegistration(), accessToken, additionalParameters));
+				loginAuthenticationToken.getClientRegistration(), accessToken, additionalParameters));
 
 		Collection<? extends GrantedAuthority> mappedAuthorities =
 			this.authoritiesMapper.mapAuthorities(oauth2User.getAuthorities());
 
 		OAuth2LoginAuthenticationToken authenticationResult = new OAuth2LoginAuthenticationToken(
-			authorizationCodeAuthentication.getClientRegistration(),
-			authorizationCodeAuthentication.getAuthorizationExchange(),
+			loginAuthenticationToken.getClientRegistration(),
+			loginAuthenticationToken.getAuthorizationExchange(),
 			oauth2User,
 			mappedAuthorities,
 			accessToken,
-			accessTokenResponse.getRefreshToken());
-		authenticationResult.setDetails(authorizationCodeAuthentication.getDetails());
+			authorizationCodeAuthenticationToken.getRefreshToken());
+		authenticationResult.setDetails(loginAuthenticationToken.getDetails());
 
 		return authenticationResult;
 	}

+ 25 - 1
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,6 +16,8 @@
 package org.springframework.security.oauth2.client.authentication;
 
 import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
 
 import org.junit.Before;
 import org.junit.Test;
@@ -119,4 +121,26 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
 		assertThat(authenticationResult.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
 		assertThat(authenticationResult.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken());
 	}
+
+	// gh-5368
+	@Test
+	public void authenticateWhenAuthorizationSuccessResponseThenAdditionalParametersIncluded() {
+		Map<String, Object> additionalParameters = new HashMap<>();
+		additionalParameters.put("param1", "value1");
+		additionalParameters.put("param2", "value2");
+
+		OAuth2AccessTokenResponse accessTokenResponse = accessTokenResponse().additionalParameters(additionalParameters)
+				.build();
+		when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse);
+
+		OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest,
+				success().build());
+
+		OAuth2AuthorizationCodeAuthenticationToken authentication = (OAuth2AuthorizationCodeAuthenticationToken) this.authenticationProvider
+				.authenticate(
+						new OAuth2AuthorizationCodeAuthenticationToken(this.clientRegistration, authorizationExchange));
+
+		assertThat(authentication.getAdditionalParameters())
+				.containsAllEntriesOf(accessTokenResponse.getAdditionalParameters());
+	}
 }