Jelajahi Sumber

Add additional parameters to OAuth2UserRequest

Fixes gh-5368
Joe Grandja 7 tahun lalu
induk
melakukan
8a0c6868cd
12 mengubah file dengan 311 tambahan dan 71 penghapusan
  1. 4 2
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProvider.java
  2. 4 1
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginReactiveAuthenticationManager.java
  3. 4 5
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProvider.java
  4. 3 3
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManager.java
  5. 20 2
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequest.java
  6. 33 1
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequest.java
  7. 47 11
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProviderTests.java
  8. 25 1
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginReactiveAuthenticationManagerTests.java
  9. 53 14
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProviderTests.java
  10. 34 0
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManagerTests.java
  11. 47 17
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequestTests.java
  12. 37 14
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequestTests.java

+ 4 - 2
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProvider.java

@@ -30,6 +30,7 @@ import org.springframework.security.oauth2.core.user.OAuth2User;
 import org.springframework.util.Assert;
 
 import java.util.Collection;
+import java.util.Map;
 
 /**
  * An implementation of an {@link AuthenticationProvider} for OAuth 2.0 Login,
@@ -101,9 +102,10 @@ public class OAuth2LoginAuthenticationProvider implements AuthenticationProvider
 					authorizationCodeAuthentication.getAuthorizationExchange()));
 
 		OAuth2AccessToken accessToken = accessTokenResponse.getAccessToken();
+		Map<String, Object> additionalParameters = accessTokenResponse.getAdditionalParameters();
 
-		OAuth2User oauth2User = this.userService.loadUser(
-			new OAuth2UserRequest(authorizationCodeAuthentication.getClientRegistration(), accessToken));
+		OAuth2User oauth2User = this.userService.loadUser(new OAuth2UserRequest(
+				authorizationCodeAuthentication.getClientRegistration(), accessToken, additionalParameters));
 
 		Collection<? extends GrantedAuthority> mappedAuthorities =
 			this.authoritiesMapper.mapAuthorities(oauth2User.getAuthorities());

+ 4 - 1
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginReactiveAuthenticationManager.java

@@ -16,6 +16,7 @@
 package org.springframework.security.oauth2.client.authentication;
 
 import java.util.Collection;
+import java.util.Map;
 
 import org.springframework.security.authentication.ReactiveAuthenticationManager;
 import org.springframework.security.core.Authentication;
@@ -109,7 +110,9 @@ public class OAuth2LoginReactiveAuthenticationManager implements
 
 	private Mono<OAuth2AuthenticationToken> authenticationResult(OAuth2LoginAuthenticationToken authorizationCodeAuthentication, OAuth2AccessTokenResponse accessTokenResponse) {
 		OAuth2AccessToken accessToken = accessTokenResponse.getAccessToken();
-		OAuth2UserRequest userRequest = new OAuth2UserRequest(authorizationCodeAuthentication.getClientRegistration(), accessToken);
+		Map<String, Object> additionalParameters = accessTokenResponse.getAdditionalParameters();
+		OAuth2UserRequest userRequest = new OAuth2UserRequest(
+				authorizationCodeAuthentication.getClientRegistration(), accessToken, additionalParameters);
 		return this.userService.loadUser(userRequest)
 				.flatMap(oauth2User -> {
 					Collection<? extends GrantedAuthority> mappedAuthorities =

+ 4 - 5
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProvider.java

@@ -139,19 +139,18 @@ public class OidcAuthorizationCodeAuthenticationProvider implements Authenticati
 
 		ClientRegistration clientRegistration = authorizationCodeAuthentication.getClientRegistration();
 
-		if (!accessTokenResponse.getAdditionalParameters().containsKey(OidcParameterNames.ID_TOKEN)) {
+		Map<String, Object> additionalParameters = accessTokenResponse.getAdditionalParameters();
+		if (!additionalParameters.containsKey(OidcParameterNames.ID_TOKEN)) {
 			OAuth2Error invalidIdTokenError = new OAuth2Error(
 				INVALID_ID_TOKEN_ERROR_CODE,
 				"Missing (required) ID Token in Token Response for Client Registration: " + clientRegistration.getRegistrationId(),
 				null);
 			throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString());
 		}
-
 		OidcIdToken idToken = createOidcToken(clientRegistration, accessTokenResponse);
 
-		OidcUser oidcUser = this.userService.loadUser(
-			new OidcUserRequest(clientRegistration, accessTokenResponse.getAccessToken(), idToken));
-
+		OidcUser oidcUser = this.userService.loadUser(new OidcUserRequest(
+				clientRegistration, accessTokenResponse.getAccessToken(), idToken, additionalParameters));
 		Collection<? extends GrantedAuthority> mappedAuthorities =
 			this.authoritiesMapper.mapAuthorities(oidcUser.getAuthorities());
 

+ 3 - 3
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManager.java

@@ -159,10 +159,10 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements
 
 	private Mono<OAuth2AuthenticationToken> authenticationResult(OAuth2LoginAuthenticationToken authorizationCodeAuthentication, OAuth2AccessTokenResponse accessTokenResponse) {
 		OAuth2AccessToken accessToken = accessTokenResponse.getAccessToken();
-
 		ClientRegistration clientRegistration = authorizationCodeAuthentication.getClientRegistration();
+		Map<String, Object> additionalParameters = accessTokenResponse.getAdditionalParameters();
 
-		if (!accessTokenResponse.getAdditionalParameters().containsKey(OidcParameterNames.ID_TOKEN)) {
+		if (!additionalParameters.containsKey(OidcParameterNames.ID_TOKEN)) {
 			OAuth2Error invalidIdTokenError = new OAuth2Error(
 					INVALID_ID_TOKEN_ERROR_CODE,
 					"Missing (required) ID Token in Token Response for Client Registration: " + clientRegistration.getRegistrationId(),
@@ -171,7 +171,7 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements
 		}
 
 		return createOidcToken(clientRegistration, accessTokenResponse)
-				.map(idToken ->  new OidcUserRequest(clientRegistration, accessToken, idToken))
+				.map(idToken ->  new OidcUserRequest(clientRegistration, accessToken, idToken, additionalParameters))
 				.flatMap(this.userService::loadUser)
 				.flatMap(oauth2User -> {
 					Collection<? extends GrantedAuthority> mappedAuthorities =

+ 20 - 2
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequest.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2017 the original author or authors.
+ * Copyright 2002-2018 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -21,6 +21,9 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.oidc.OidcIdToken;
 import org.springframework.util.Assert;
 
+import java.util.Collections;
+import java.util.Map;
+
 /**
  * Represents a request the {@link OidcUserService} uses
  * when initiating a request to the UserInfo Endpoint.
@@ -45,7 +48,22 @@ public class OidcUserRequest extends OAuth2UserRequest {
 	public OidcUserRequest(ClientRegistration clientRegistration,
 							OAuth2AccessToken accessToken, OidcIdToken idToken) {
 
-		super(clientRegistration, accessToken);
+		this(clientRegistration, accessToken, idToken, Collections.emptyMap());
+	}
+
+	/**
+	 * Constructs an {@code OidcUserRequest} using the provided parameters.
+	 *
+	 * @since 5.1
+	 * @param clientRegistration the client registration
+	 * @param accessToken the access token credential
+	 * @param idToken the ID Token
+	 * @param additionalParameters the additional parameters, may be empty
+	 */
+	public OidcUserRequest(ClientRegistration clientRegistration, OAuth2AccessToken accessToken,
+							OidcIdToken idToken, Map<String, Object> additionalParameters) {
+
+		super(clientRegistration, accessToken, additionalParameters);
 		Assert.notNull(idToken, "idToken cannot be null");
 		this.idToken = idToken;
 	}

+ 33 - 1
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequest.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2017 the original author or authors.
+ * Copyright 2002-2018 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -18,6 +18,11 @@ package org.springframework.security.oauth2.client.userinfo;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.util.Assert;
+import org.springframework.util.CollectionUtils;
+
+import java.util.Collections;
+import java.util.LinkedHashMap;
+import java.util.Map;
 
 /**
  * Represents a request the {@link OAuth2UserService} uses
@@ -32,6 +37,7 @@ import org.springframework.util.Assert;
 public class OAuth2UserRequest {
 	private final ClientRegistration clientRegistration;
 	private final OAuth2AccessToken accessToken;
+	private final Map<String, Object> additionalParameters;
 
 	/**
 	 * Constructs an {@code OAuth2UserRequest} using the provided parameters.
@@ -40,10 +46,26 @@ public class OAuth2UserRequest {
 	 * @param accessToken the access token
 	 */
 	public OAuth2UserRequest(ClientRegistration clientRegistration, OAuth2AccessToken accessToken) {
+		this(clientRegistration, accessToken, Collections.emptyMap());
+	}
+
+	/**
+	 * Constructs an {@code OAuth2UserRequest} using the provided parameters.
+	 *
+	 * @since 5.1
+	 * @param clientRegistration the client registration
+	 * @param accessToken the access token
+	 * @param additionalParameters the additional parameters, may be empty
+	 */
+	public OAuth2UserRequest(ClientRegistration clientRegistration, OAuth2AccessToken accessToken,
+								Map<String, Object> additionalParameters) {
 		Assert.notNull(clientRegistration, "clientRegistration cannot be null");
 		Assert.notNull(accessToken, "accessToken cannot be null");
 		this.clientRegistration = clientRegistration;
 		this.accessToken = accessToken;
+		this.additionalParameters = Collections.unmodifiableMap(
+				CollectionUtils.isEmpty(additionalParameters) ?
+				Collections.emptyMap() : new LinkedHashMap<>(additionalParameters));
 	}
 
 	/**
@@ -63,4 +85,14 @@ public class OAuth2UserRequest {
 	public OAuth2AccessToken getAccessToken() {
 		return this.accessToken;
 	}
+
+	/**
+	 * Returns the additional parameters that may be used in the request.
+	 *
+	 * @since 5.1
+	 * @return a {@code Map} of the additional parameters, may be empty.
+	 */
+	public Map<String, Object> getAdditionalParameters() {
+		return this.additionalParameters;
+	}
 }

+ 47 - 11
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProviderTests.java

@@ -20,6 +20,7 @@ import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
 import org.junit.runner.RunWith;
+import org.mockito.ArgumentCaptor;
 import org.mockito.stubbing.Answer;
 import org.powermock.core.classloader.annotations.PrepareForTest;
 import org.powermock.modules.junit4.PowerMockRunner;
@@ -35,17 +36,20 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
-import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
 import org.springframework.security.oauth2.core.user.OAuth2User;
 
+import java.time.Instant;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.LinkedHashSet;
 import java.util.List;
+import java.util.Map;
+import java.util.Set;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.hamcrest.CoreMatchers.containsString;
@@ -164,11 +168,7 @@ public class OAuth2LoginAuthenticationProviderTests {
 
 	@Test
 	public void authenticateWhenLoginSuccessThenReturnAuthentication() {
-		OAuth2AccessToken accessToken = mock(OAuth2AccessToken.class);
-		OAuth2RefreshToken refreshToken = mock(OAuth2RefreshToken.class);
-		OAuth2AccessTokenResponse accessTokenResponse = mock(OAuth2AccessTokenResponse.class);
-		when(accessTokenResponse.getAccessToken()).thenReturn(accessToken);
-		when(accessTokenResponse.getRefreshToken()).thenReturn(refreshToken);
+		OAuth2AccessTokenResponse accessTokenResponse = this.accessTokenSuccessResponse();
 		when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse);
 
 		OAuth2User principal = mock(OAuth2User.class);
@@ -187,15 +187,13 @@ public class OAuth2LoginAuthenticationProviderTests {
 		assertThat(authentication.getAuthorities()).isEqualTo(authorities);
 		assertThat(authentication.getClientRegistration()).isEqualTo(this.clientRegistration);
 		assertThat(authentication.getAuthorizationExchange()).isEqualTo(this.authorizationExchange);
-		assertThat(authentication.getAccessToken()).isEqualTo(accessToken);
-		assertThat(authentication.getRefreshToken()).isEqualTo(refreshToken);
+		assertThat(authentication.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
+		assertThat(authentication.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken());
 	}
 
 	@Test
 	public void authenticateWhenAuthoritiesMapperSetThenReturnMappedAuthorities() {
-		OAuth2AccessToken accessToken = mock(OAuth2AccessToken.class);
-		OAuth2AccessTokenResponse accessTokenResponse = mock(OAuth2AccessTokenResponse.class);
-		when(accessTokenResponse.getAccessToken()).thenReturn(accessToken);
+		OAuth2AccessTokenResponse accessTokenResponse = this.accessTokenSuccessResponse();
 		when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse);
 
 		OAuth2User principal = mock(OAuth2User.class);
@@ -216,4 +214,42 @@ public class OAuth2LoginAuthenticationProviderTests {
 
 		assertThat(authentication.getAuthorities()).isEqualTo(mappedAuthorities);
 	}
+
+	// gh-5368
+	@Test
+	public void authenticateWhenTokenSuccessResponseThenAdditionalParametersAddedToUserRequest() {
+		OAuth2AccessTokenResponse accessTokenResponse = this.accessTokenSuccessResponse();
+		when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse);
+
+		OAuth2User principal = mock(OAuth2User.class);
+		List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
+		when(principal.getAuthorities()).thenAnswer(
+				(Answer<List<GrantedAuthority>>) invocation -> authorities);
+		ArgumentCaptor<OAuth2UserRequest> userRequestArgCaptor = ArgumentCaptor.forClass(OAuth2UserRequest.class);
+		when(this.userService.loadUser(userRequestArgCaptor.capture())).thenReturn(principal);
+
+		this.authenticationProvider.authenticate(
+				new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
+
+		assertThat(userRequestArgCaptor.getValue().getAdditionalParameters()).containsAllEntriesOf(
+				accessTokenResponse.getAdditionalParameters());
+	}
+
+	private OAuth2AccessTokenResponse accessTokenSuccessResponse() {
+		Instant expiresAt = Instant.now().plusSeconds(5);
+		Set<String> scopes = new LinkedHashSet<>(Arrays.asList("scope1", "scope2"));
+		Map<String, Object> additionalParameters = new HashMap<>();
+		additionalParameters.put("param1", "value1");
+		additionalParameters.put("param2", "value2");
+
+		return OAuth2AccessTokenResponse
+				.withToken("access-token-1234")
+				.tokenType(OAuth2AccessToken.TokenType.BEARER)
+				.expiresIn(expiresAt.getEpochSecond())
+				.scopes(scopes)
+				.refreshToken("refresh-token-1234")
+				.additionalParameters(additionalParameters)
+				.build();
+
+	}
 }

+ 25 - 1
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginReactiveAuthenticationManagerTests.java

@@ -23,11 +23,14 @@ import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.when;
 
 import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
 
 import org.junit.Before;
 import org.junit.Ignore;
 import org.junit.Test;
 import org.junit.runner.RunWith;
+import org.mockito.ArgumentCaptor;
 import org.mockito.Mock;
 import org.mockito.junit.MockitoJUnitRunner;
 import org.springframework.security.authentication.TestingAuthenticationToken;
@@ -164,7 +167,7 @@ public class OAuth2LoginReactiveAuthenticationManagerTests {
 	}
 
 	@Test
-	public void authenticationWhenOAuth2UserNotFoundThenSuccess() {
+	public void authenticationWhenOAuth2UserFoundThenSuccess() {
 		OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo")
 				.tokenType(OAuth2AccessToken.TokenType.BEARER)
 				.build();
@@ -179,6 +182,27 @@ public class OAuth2LoginReactiveAuthenticationManagerTests {
 		assertThat(result.isAuthenticated()).isTrue();
 	}
 
+	// gh-5368
+	@Test
+	public void authenticateWhenTokenSuccessResponseThenAdditionalParametersAddedToUserRequest() {
+		Map<String, Object> additionalParameters = new HashMap<>();
+		additionalParameters.put("param1", "value1");
+		additionalParameters.put("param2", "value2");
+		OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo")
+				.tokenType(OAuth2AccessToken.TokenType.BEARER)
+				.additionalParameters(additionalParameters)
+				.build();
+		when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse));
+		DefaultOAuth2User user = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), Collections.singletonMap("user", "rob"), "user");
+		ArgumentCaptor<OAuth2UserRequest> userRequestArgCaptor = ArgumentCaptor.forClass(OAuth2UserRequest.class);
+		when(this.userService.loadUser(userRequestArgCaptor.capture())).thenReturn(Mono.just(user));
+
+		this.manager.authenticate(loginToken()).block();
+
+		assertThat(userRequestArgCaptor.getValue().getAdditionalParameters())
+				.containsAllEntriesOf(accessTokenResponse.getAdditionalParameters());
+	}
+
 	private OAuth2LoginAuthenticationToken loginToken() {
 		ClientRegistration clientRegistration = this.registration.build();
 		OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest

+ 53 - 14
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProviderTests.java

@@ -20,6 +20,7 @@ import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
 import org.junit.runner.RunWith;
+import org.mockito.ArgumentCaptor;
 import org.mockito.stubbing.Answer;
 import org.powermock.api.mockito.PowerMockito;
 import org.powermock.core.classloader.annotations.PrepareForTest;
@@ -37,7 +38,6 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
-import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
@@ -55,6 +55,7 @@ import java.util.HashMap;
 import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.hamcrest.CoreMatchers.containsString;
@@ -78,8 +79,6 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
 	private OAuth2AuthorizationExchange authorizationExchange;
 	private OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient;
 	private OAuth2AccessTokenResponse accessTokenResponse;
-	private OAuth2AccessToken accessToken;
-	private OAuth2RefreshToken refreshToken;
 	private OAuth2UserService<OidcUserRequest, OidcUser> userService;
 	private OidcAuthorizationCodeAuthenticationProvider authenticationProvider;
 
@@ -95,9 +94,7 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
 		this.authorizationResponse = mock(OAuth2AuthorizationResponse.class);
 		this.authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, this.authorizationResponse);
 		this.accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class);
-		this.accessTokenResponse = mock(OAuth2AccessTokenResponse.class);
-		this.accessToken = mock(OAuth2AccessToken.class);
-		this.refreshToken = mock(OAuth2RefreshToken.class);
+		this.accessTokenResponse = this.accessTokenSuccessResponse();
 		this.userService = mock(OAuth2UserService.class);
 		this.authenticationProvider = PowerMockito.spy(
 			new OidcAuthorizationCodeAuthenticationProvider(this.accessTokenResponseClient, this.userService));
@@ -111,11 +108,6 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
 		when(this.authorizationResponse.getState()).thenReturn("12345");
 		when(this.authorizationRequest.getRedirectUri()).thenReturn("http://example.com");
 		when(this.authorizationResponse.getRedirectUri()).thenReturn("http://example.com");
-		when(this.accessTokenResponse.getAccessToken()).thenReturn(this.accessToken);
-		when(this.accessTokenResponse.getRefreshToken()).thenReturn(this.refreshToken);
-		Map<String, Object> additionalParameters = new HashMap<>();
-		additionalParameters.put(OidcParameterNames.ID_TOKEN, "id-token");
-		when(this.accessTokenResponse.getAdditionalParameters()).thenReturn(additionalParameters);
 		when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(this.accessTokenResponse);
 	}
 
@@ -194,7 +186,11 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
 		this.exception.expect(OAuth2AuthenticationException.class);
 		this.exception.expectMessage(containsString("invalid_id_token"));
 
-		when(this.accessTokenResponse.getAdditionalParameters()).thenReturn(Collections.emptyMap());
+		OAuth2AccessTokenResponse accessTokenResponse =
+				OAuth2AccessTokenResponse.withResponse(this.accessTokenSuccessResponse())
+						.additionalParameters(Collections.emptyMap())
+						.build();
+		when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse);
 
 		this.authenticationProvider.authenticate(
 			new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
@@ -368,8 +364,8 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
 		assertThat(authentication.getAuthorities()).isEqualTo(authorities);
 		assertThat(authentication.getClientRegistration()).isEqualTo(this.clientRegistration);
 		assertThat(authentication.getAuthorizationExchange()).isEqualTo(this.authorizationExchange);
-		assertThat(authentication.getAccessToken()).isEqualTo(this.accessToken);
-		assertThat(authentication.getRefreshToken()).isEqualTo(this.refreshToken);
+		assertThat(authentication.getAccessToken()).isEqualTo(this.accessTokenResponse.getAccessToken());
+		assertThat(authentication.getRefreshToken()).isEqualTo(this.accessTokenResponse.getRefreshToken());
 	}
 
 	@Test
@@ -400,6 +396,30 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
 		assertThat(authentication.getAuthorities()).isEqualTo(mappedAuthorities);
 	}
 
+	// gh-5368
+	@Test
+	public void authenticateWhenTokenSuccessResponseThenAdditionalParametersAddedToUserRequest() throws Exception {
+		Map<String, Object> claims = new HashMap<>();
+		claims.put(IdTokenClaimNames.ISS, "https://provider.com");
+		claims.put(IdTokenClaimNames.SUB, "subject1");
+		claims.put(IdTokenClaimNames.AUD, Arrays.asList("client1", "client2"));
+		claims.put(IdTokenClaimNames.AZP, "client1");
+		this.setUpIdToken(claims);
+
+		OidcUser principal = mock(OidcUser.class);
+		List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
+		when(principal.getAuthorities()).thenAnswer(
+				(Answer<List<GrantedAuthority>>) invocation -> authorities);
+		ArgumentCaptor<OidcUserRequest> userRequestArgCaptor = ArgumentCaptor.forClass(OidcUserRequest.class);
+		when(this.userService.loadUser(userRequestArgCaptor.capture())).thenReturn(principal);
+
+		this.authenticationProvider.authenticate(new OAuth2LoginAuthenticationToken(
+				this.clientRegistration, this.authorizationExchange));
+
+		assertThat(userRequestArgCaptor.getValue().getAdditionalParameters()).containsAllEntriesOf(
+				this.accessTokenResponse.getAdditionalParameters());
+	}
+
 	private void setUpIdToken(Map<String, Object> claims) throws Exception {
 		Instant issuedAt = Instant.now();
 		Instant expiresAt = Instant.from(issuedAt).plusSeconds(3600);
@@ -416,4 +436,23 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
 		when(jwtDecoder.decode(anyString())).thenReturn(idToken);
 		PowerMockito.doReturn(jwtDecoder).when(this.authenticationProvider, "getJwtDecoder", any(ClientRegistration.class));
 	}
+
+	private OAuth2AccessTokenResponse accessTokenSuccessResponse() {
+		Instant expiresAt = Instant.now().plusSeconds(5);
+		Set<String> scopes = new LinkedHashSet<>(Arrays.asList("openid", "profile", "email"));
+		Map<String, Object> additionalParameters = new HashMap<>();
+		additionalParameters.put("param1", "value1");
+		additionalParameters.put("param2", "value2");
+		additionalParameters.put(OidcParameterNames.ID_TOKEN, "id-token");
+
+		return OAuth2AccessTokenResponse
+				.withToken("access-token-1234")
+				.tokenType(OAuth2AccessToken.TokenType.BEARER)
+				.expiresIn(expiresAt.getEpochSecond())
+				.scopes(scopes)
+				.refreshToken("refresh-token-1234")
+				.additionalParameters(additionalParameters)
+				.build();
+
+	}
 }

+ 34 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManagerTests.java

@@ -19,6 +19,7 @@ package org.springframework.security.oauth2.client.oidc.authentication;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
+import org.mockito.ArgumentCaptor;
 import org.mockito.Mock;
 import org.mockito.junit.MockitoJUnitRunner;
 import org.springframework.security.authentication.TestingAuthenticationToken;
@@ -217,6 +218,39 @@ public class OidcAuthorizationCodeReactiveAuthenticationManagerTests {
 		assertThat(result.isAuthenticated()).isTrue();
 	}
 
+	// gh-5368
+	@Test
+	public void authenticateWhenTokenSuccessResponseThenAdditionalParametersAddedToUserRequest() {
+		Map<String, Object> additionalParameters = new HashMap<>();
+		additionalParameters.put(OidcParameterNames.ID_TOKEN, this.idToken.getTokenValue());
+		additionalParameters.put("param1", "value1");
+		additionalParameters.put("param2", "value2");
+		OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo")
+				.tokenType(OAuth2AccessToken.TokenType.BEARER)
+				.additionalParameters(additionalParameters)
+				.build();
+
+		Map<String, Object> claims = new HashMap<>();
+		claims.put(IdTokenClaimNames.ISS, "https://issuer.example.com");
+		claims.put(IdTokenClaimNames.SUB, "rob");
+		claims.put(IdTokenClaimNames.AUD, Arrays.asList("clientId"));
+		Instant issuedAt = Instant.now();
+		Instant expiresAt = Instant.from(issuedAt).plusSeconds(3600);
+		Jwt idToken = new Jwt("id-token", issuedAt, expiresAt, claims, claims);
+
+		when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse));
+		DefaultOidcUser user = new DefaultOidcUser(AuthorityUtils.createAuthorityList("ROLE_USER"), this.idToken);
+		ArgumentCaptor<OidcUserRequest> userRequestArgCaptor = ArgumentCaptor.forClass(OidcUserRequest.class);
+		when(this.userService.loadUser(userRequestArgCaptor.capture())).thenReturn(Mono.just(user));
+		when(this.jwtDecoder.decode(any())).thenReturn(Mono.just(idToken));
+		this.manager.setDecoderFactory(c -> this.jwtDecoder);
+
+		this.manager.authenticate(loginToken()).block();
+
+		assertThat(userRequestArgCaptor.getValue().getAdditionalParameters())
+				.containsAllEntriesOf(accessTokenResponse.getAdditionalParameters());
+	}
+
 	private OAuth2LoginAuthenticationToken loginToken() {
 		ClientRegistration clientRegistration = this.registration.build();
 		OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest

+ 47 - 17
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequestTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2017 the original author or authors.
+ * Copyright 2002-2018 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -17,57 +17,87 @@ package org.springframework.security.oauth2.client.oidc.userinfo;
 
 import org.junit.Before;
 import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.powermock.core.classloader.annotations.PrepareForTest;
-import org.powermock.modules.junit4.PowerMockRunner;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames;
 import org.springframework.security.oauth2.core.oidc.OidcIdToken;
 
+import java.time.Instant;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.LinkedHashSet;
+import java.util.Map;
+
 import static org.assertj.core.api.Assertions.assertThat;
-import static org.mockito.Mockito.mock;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
 /**
  * Tests for {@link OidcUserRequest}.
  *
  * @author Joe Grandja
  */
-@RunWith(PowerMockRunner.class)
-@PrepareForTest(ClientRegistration.class)
 public class OidcUserRequestTests {
 	private ClientRegistration clientRegistration;
 	private OAuth2AccessToken accessToken;
 	private OidcIdToken idToken;
+	private Map<String, Object> additionalParameters;
 
 	@Before
 	public void setUp() {
-		this.clientRegistration = mock(ClientRegistration.class);
-		this.accessToken = mock(OAuth2AccessToken.class);
-		this.idToken = mock(OidcIdToken.class);
+		this.clientRegistration = ClientRegistration.withRegistrationId("registration-1")
+				.clientId("client-1")
+				.clientSecret("secret")
+				.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+				.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+				.redirectUriTemplate("https://client.com")
+				.scope(new LinkedHashSet<>(Arrays.asList("openid", "profile")))
+				.authorizationUri("https://provider.com/oauth2/authorization")
+				.tokenUri("https://provider.com/oauth2/token")
+				.jwkSetUri("https://provider.com/keys")
+				.clientName("Client 1")
+				.build();
+		this.accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
+				"access-token-1234", Instant.now(), Instant.now().plusSeconds(60),
+				new LinkedHashSet<>(Arrays.asList("scope1", "scope2")));
+		Map<String, Object> claims = new HashMap<>();
+		claims.put(IdTokenClaimNames.ISS, "https://provider.com");
+		claims.put(IdTokenClaimNames.SUB, "subject1");
+		claims.put(IdTokenClaimNames.AZP, "client-1");
+		this.idToken = new OidcIdToken("id-token-1234", Instant.now(),
+				Instant.now().plusSeconds(3600), claims);
+		this.additionalParameters = new HashMap<>();
+		this.additionalParameters.put("param1", "value1");
+		this.additionalParameters.put("param2", "value2");
 	}
 
-	@Test(expected = IllegalArgumentException.class)
+	@Test
 	public void constructorWhenClientRegistrationIsNullThenThrowIllegalArgumentException() {
-		new OidcUserRequest(null, this.accessToken, this.idToken);
+		assertThatThrownBy(() -> new OidcUserRequest(null, this.accessToken, this.idToken))
+				.isInstanceOf(IllegalArgumentException.class);
 	}
 
-	@Test(expected = IllegalArgumentException.class)
+	@Test
 	public void constructorWhenAccessTokenIsNullThenThrowIllegalArgumentException() {
-		new OidcUserRequest(this.clientRegistration, null, this.idToken);
+		assertThatThrownBy(() -> new OidcUserRequest(this.clientRegistration, null, this.idToken))
+				.isInstanceOf(IllegalArgumentException.class);
 	}
 
-	@Test(expected = IllegalArgumentException.class)
+	@Test
 	public void constructorWhenIdTokenIsNullThenThrowIllegalArgumentException() {
-		new OidcUserRequest(this.clientRegistration, this.accessToken, null);
+		assertThatThrownBy(() -> new OidcUserRequest(this.clientRegistration, this.accessToken, null))
+				.isInstanceOf(IllegalArgumentException.class);
 	}
 
 	@Test
 	public void constructorWhenAllParametersProvidedAndValidThenCreated() {
 		OidcUserRequest userRequest = new OidcUserRequest(
-			this.clientRegistration, this.accessToken, this.idToken);
+			this.clientRegistration, this.accessToken, this.idToken, this.additionalParameters);
 
 		assertThat(userRequest.getClientRegistration()).isEqualTo(this.clientRegistration);
 		assertThat(userRequest.getAccessToken()).isEqualTo(this.accessToken);
 		assertThat(userRequest.getIdToken()).isEqualTo(this.idToken);
+		assertThat(userRequest.getAdditionalParameters()).containsAllEntriesOf(this.additionalParameters);
 	}
 }

+ 37 - 14
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/OAuth2UserRequestTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2017 the original author or authors.
+ * Copyright 2002-2018 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -17,47 +17,70 @@ package org.springframework.security.oauth2.client.userinfo;
 
 import org.junit.Before;
 import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.powermock.core.classloader.annotations.PrepareForTest;
-import org.powermock.modules.junit4.PowerMockRunner;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 
+import java.time.Instant;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.LinkedHashSet;
+import java.util.Map;
+
 import static org.assertj.core.api.Assertions.assertThat;
-import static org.mockito.Mockito.mock;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
 /**
  * Tests for {@link OAuth2UserRequest}.
  *
  * @author Joe Grandja
  */
-@RunWith(PowerMockRunner.class)
-@PrepareForTest(ClientRegistration.class)
 public class OAuth2UserRequestTests {
 	private ClientRegistration clientRegistration;
 	private OAuth2AccessToken accessToken;
+	private Map<String, Object> additionalParameters;
 
 	@Before
 	public void setUp() {
-		this.clientRegistration = mock(ClientRegistration.class);
-		this.accessToken = mock(OAuth2AccessToken.class);
+		this.clientRegistration = ClientRegistration.withRegistrationId("registration-1")
+				.clientId("client-1")
+				.clientSecret("secret")
+				.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+				.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+				.redirectUriTemplate("https://client.com")
+				.scope(new LinkedHashSet<>(Arrays.asList("scope1", "scope2")))
+				.authorizationUri("https://provider.com/oauth2/authorization")
+				.tokenUri("https://provider.com/oauth2/token")
+				.clientName("Client 1")
+				.build();
+		this.accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
+				"access-token-1234", Instant.now(), Instant.now().plusSeconds(60),
+				new LinkedHashSet<>(Arrays.asList("scope1", "scope2")));
+		this.additionalParameters = new HashMap<>();
+		this.additionalParameters.put("param1", "value1");
+		this.additionalParameters.put("param2", "value2");
 	}
 
-	@Test(expected = IllegalArgumentException.class)
+	@Test
 	public void constructorWhenClientRegistrationIsNullThenThrowIllegalArgumentException() {
-		new OAuth2UserRequest(null, this.accessToken);
+		assertThatThrownBy(() -> new OAuth2UserRequest(null, this.accessToken))
+				.isInstanceOf(IllegalArgumentException.class);
 	}
 
-	@Test(expected = IllegalArgumentException.class)
+	@Test
 	public void constructorWhenAccessTokenIsNullThenThrowIllegalArgumentException() {
-		new OAuth2UserRequest(this.clientRegistration, null);
+		assertThatThrownBy(() -> new OAuth2UserRequest(this.clientRegistration, null))
+				.isInstanceOf(IllegalArgumentException.class);
 	}
 
 	@Test
 	public void constructorWhenAllParametersProvidedAndValidThenCreated() {
-		OAuth2UserRequest userRequest = new OAuth2UserRequest(this.clientRegistration, this.accessToken);
+		OAuth2UserRequest userRequest = new OAuth2UserRequest(
+				this.clientRegistration, this.accessToken, this.additionalParameters);
 
 		assertThat(userRequest.getClientRegistration()).isEqualTo(this.clientRegistration);
 		assertThat(userRequest.getAccessToken()).isEqualTo(this.accessToken);
+		assertThat(userRequest.getAdditionalParameters()).containsAllEntriesOf(this.additionalParameters);
 	}
 }