Browse Source

Propagate additional token request parameters

Closes gh-226
Joe Grandja 4 years ago
parent
commit
7652d0ebbe
9 changed files with 115 additions and 94 deletions
  1. 2 1
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationGrantAuthenticationToken.java
  2. 9 15
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationToken.java
  3. 8 14
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationToken.java
  4. 31 7
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java
  5. 10 6
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java
  6. 11 13
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationTokenTests.java
  7. 11 11
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java
  8. 11 16
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationTokenTests.java
  9. 22 11
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java

+ 2 - 1
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationGrantAuthenticationToken.java

@@ -16,6 +16,7 @@
 package org.springframework.security.oauth2.server.authorization.authentication;
 
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.Map;
 
 import org.springframework.lang.Nullable;
@@ -57,7 +58,7 @@ public class OAuth2AuthorizationGrantAuthenticationToken extends AbstractAuthent
 		this.clientPrincipal = clientPrincipal;
 		this.additionalParameters = Collections.unmodifiableMap(
 				additionalParameters != null ?
-						additionalParameters :
+						new HashMap<>(additionalParameters) :
 						Collections.emptyMap());
 	}
 

+ 9 - 15
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationToken.java

@@ -16,12 +16,13 @@
 package org.springframework.security.oauth2.server.authorization.authentication;
 
 import java.util.Collections;
-import java.util.LinkedHashSet;
+import java.util.HashSet;
+import java.util.Map;
 import java.util.Set;
 
+import org.springframework.lang.Nullable;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
-import org.springframework.util.Assert;
 
 /**
  * An {@link Authentication} implementation used for the OAuth 2.0 Client Credentials Grant.
@@ -34,25 +35,18 @@ import org.springframework.util.Assert;
 public class OAuth2ClientCredentialsAuthenticationToken extends OAuth2AuthorizationGrantAuthenticationToken {
 	private final Set<String> scopes;
 
-	/**
-	 * Constructs an {@code OAuth2ClientCredentialsAuthenticationToken} using the provided parameters.
-	 *
-	 * @param clientPrincipal the authenticated client principal
-	 */
-	public OAuth2ClientCredentialsAuthenticationToken(Authentication clientPrincipal) {
-		this(clientPrincipal, Collections.emptySet());
-	}
-
 	/**
 	 * Constructs an {@code OAuth2ClientCredentialsAuthenticationToken} using the provided parameters.
 	 *
 	 * @param clientPrincipal the authenticated client principal
 	 * @param scopes the requested scope(s)
+	 * @param additionalParameters the additional parameters
 	 */
-	public OAuth2ClientCredentialsAuthenticationToken(Authentication clientPrincipal, Set<String> scopes) {
-		super(AuthorizationGrantType.CLIENT_CREDENTIALS, clientPrincipal, null);
-		Assert.notNull(scopes, "scopes cannot be null");
-		this.scopes = Collections.unmodifiableSet(new LinkedHashSet<>(scopes));
+	public OAuth2ClientCredentialsAuthenticationToken(Authentication clientPrincipal,
+			@Nullable Set<String> scopes, @Nullable Map<String, Object> additionalParameters) {
+		super(AuthorizationGrantType.CLIENT_CREDENTIALS, clientPrincipal, additionalParameters);
+		this.scopes = Collections.unmodifiableSet(
+				scopes != null ? new HashSet<>(scopes) : Collections.emptySet());
 	}
 
 	/**

+ 8 - 14
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationToken.java

@@ -16,8 +16,11 @@
 package org.springframework.security.oauth2.server.authorization.authentication;
 
 import java.util.Collections;
+import java.util.HashSet;
+import java.util.Map;
 import java.util.Set;
 
+import org.springframework.lang.Nullable;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.util.Assert;
@@ -34,30 +37,21 @@ public class OAuth2RefreshTokenAuthenticationToken extends OAuth2AuthorizationGr
 	private final String refreshToken;
 	private final Set<String> scopes;
 
-	/**
-	 * Constructs an {@code OAuth2RefreshTokenAuthenticationToken} using the provided parameters.
-	 *
-	 * @param refreshToken the refresh token
-	 * @param clientPrincipal the authenticated client principal
-	 */
-	public OAuth2RefreshTokenAuthenticationToken(String refreshToken, Authentication clientPrincipal) {
-		this(refreshToken, clientPrincipal, Collections.emptySet());
-	}
-
 	/**
 	 * Constructs an {@code OAuth2RefreshTokenAuthenticationToken} using the provided parameters.
 	 *
 	 * @param refreshToken the refresh token
 	 * @param clientPrincipal the authenticated client principal
 	 * @param scopes the requested scope(s)
+	 * @param additionalParameters the additional parameters
 	 */
 	public OAuth2RefreshTokenAuthenticationToken(String refreshToken, Authentication clientPrincipal,
-			Set<String> scopes) {
-		super(AuthorizationGrantType.REFRESH_TOKEN, clientPrincipal, null);
+			@Nullable Set<String> scopes, @Nullable Map<String, Object> additionalParameters) {
+		super(AuthorizationGrantType.REFRESH_TOKEN, clientPrincipal, additionalParameters);
 		Assert.hasText(refreshToken, "refreshToken cannot be empty");
-		Assert.notNull(scopes, "scopes cannot be null");
 		this.refreshToken = refreshToken;
-		this.scopes = scopes;
+		this.scopes = Collections.unmodifiableSet(
+				scopes != null ? new HashSet<>(scopes) : Collections.emptySet());
 	}
 
 	/**

+ 31 - 7
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java

@@ -229,6 +229,7 @@ public class OAuth2TokenEndpointFilter extends OncePerRequestFilter {
 				throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI);
 			}
 
+			// @formatter:off
 			Map<String, Object> additionalParameters = parameters
 					.entrySet()
 					.stream()
@@ -237,8 +238,10 @@ public class OAuth2TokenEndpointFilter extends OncePerRequestFilter {
 							!e.getKey().equals(OAuth2ParameterNames.CODE) &&
 							!e.getKey().equals(OAuth2ParameterNames.REDIRECT_URI))
 					.collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().get(0)));
+			// @formatter:on
 
-			return new OAuth2AuthorizationCodeAuthenticationToken(code, clientPrincipal, redirectUri, additionalParameters);
+			return new OAuth2AuthorizationCodeAuthenticationToken(
+					code, clientPrincipal, redirectUri, additionalParameters);
 		}
 	}
 
@@ -269,13 +272,24 @@ public class OAuth2TokenEndpointFilter extends OncePerRequestFilter {
 					parameters.get(OAuth2ParameterNames.SCOPE).size() != 1) {
 				throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.SCOPE);
 			}
+			Set<String> requestedScopes = null;
 			if (StringUtils.hasText(scope)) {
-				Set<String> requestedScopes = new HashSet<>(
+				requestedScopes = new HashSet<>(
 						Arrays.asList(StringUtils.delimitedListToStringArray(scope, " ")));
-				return new OAuth2RefreshTokenAuthenticationToken(refreshToken, clientPrincipal, requestedScopes);
 			}
 
-			return new OAuth2RefreshTokenAuthenticationToken(refreshToken, clientPrincipal);
+			// @formatter:off
+			Map<String, Object> additionalParameters = parameters
+					.entrySet()
+					.stream()
+					.filter(e -> !e.getKey().equals(OAuth2ParameterNames.GRANT_TYPE) &&
+							!e.getKey().equals(OAuth2ParameterNames.REFRESH_TOKEN) &&
+							!e.getKey().equals(OAuth2ParameterNames.SCOPE))
+					.collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().get(0)));
+			// @formatter:on
+
+			return new OAuth2RefreshTokenAuthenticationToken(
+					refreshToken, clientPrincipal, requestedScopes, additionalParameters);
 		}
 	}
 
@@ -299,13 +313,23 @@ public class OAuth2TokenEndpointFilter extends OncePerRequestFilter {
 					parameters.get(OAuth2ParameterNames.SCOPE).size() != 1) {
 				throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.SCOPE);
 			}
+			Set<String> requestedScopes = null;
 			if (StringUtils.hasText(scope)) {
-				Set<String> requestedScopes = new HashSet<>(
+				requestedScopes = new HashSet<>(
 						Arrays.asList(StringUtils.delimitedListToStringArray(scope, " ")));
-				return new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal, requestedScopes);
 			}
 
-			return new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal);
+			// @formatter:off
+			Map<String, Object> additionalParameters = parameters
+					.entrySet()
+					.stream()
+					.filter(e -> !e.getKey().equals(OAuth2ParameterNames.GRANT_TYPE) &&
+							!e.getKey().equals(OAuth2ParameterNames.SCOPE))
+					.collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().get(0)));
+			// @formatter:on
+
+			return new OAuth2ClientCredentialsAuthenticationToken(
+					clientPrincipal, requestedScopes, additionalParameters);
 		}
 	}
 }

+ 10 - 6
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java

@@ -108,7 +108,8 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build();
 		TestingAuthenticationToken clientPrincipal = new TestingAuthenticationToken(
 				registeredClient.getClientId(), registeredClient.getClientSecret());
-		OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal);
+		OAuth2ClientCredentialsAuthenticationToken authentication =
+				new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal, null, null);
 
 		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
 				.isInstanceOf(OAuth2AuthenticationException.class)
@@ -122,7 +123,8 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build();
 		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(
 				registeredClient.getClientId(), registeredClient.getClientSecret(), ClientAuthenticationMethod.BASIC, null);
-		OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal);
+		OAuth2ClientCredentialsAuthenticationToken authentication =
+				new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal, null, null);
 
 		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
 				.isInstanceOf(OAuth2AuthenticationException.class)
@@ -137,7 +139,8 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests {
 				.authorizationGrantTypes(grantTypes -> grantTypes.remove(AuthorizationGrantType.CLIENT_CREDENTIALS))
 				.build();
 		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
-		OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal);
+		OAuth2ClientCredentialsAuthenticationToken authentication =
+				new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal, null, null);
 
 		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
 				.isInstanceOf(OAuth2AuthenticationException.class)
@@ -151,7 +154,7 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build();
 		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
 		OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken(
-				clientPrincipal, Collections.singleton("invalid-scope"));
+				clientPrincipal, Collections.singleton("invalid-scope"), null);
 
 		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
 				.isInstanceOf(OAuth2AuthenticationException.class)
@@ -166,7 +169,7 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests {
 		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
 		Set<String> requestedScope = Collections.singleton("scope1");
 		OAuth2ClientCredentialsAuthenticationToken authentication =
-				new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal, requestedScope);
+				new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal, requestedScope, null);
 
 		when(this.jwtEncoder.encode(any(), any()))
 				.thenReturn(createJwt(Collections.singleton("mapped-scoped")));
@@ -180,7 +183,8 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests {
 	public void authenticateWhenValidAuthenticationThenReturnAccessToken() {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build();
 		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
-		OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal);
+		OAuth2ClientCredentialsAuthenticationToken authentication =
+				new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal, null, null);
 
 		when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt(registeredClient.getScopes()));
 

+ 11 - 13
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationTokenTests.java

@@ -16,6 +16,7 @@
 package org.springframework.security.oauth2.server.authorization.authentication;
 
 import java.util.Collections;
+import java.util.Map;
 import java.util.Set;
 
 import org.junit.Test;
@@ -34,42 +35,39 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy;
 public class OAuth2ClientCredentialsAuthenticationTokenTests {
 	private final OAuth2ClientAuthenticationToken clientPrincipal =
 			new OAuth2ClientAuthenticationToken(TestRegisteredClients.registeredClient().build());
+	private Set<String> scopes = Collections.singleton("scope1");
+	private Map<String, Object> additionalParameters = Collections.singletonMap("param1", "value1");
 
 	@Test
 	public void constructorWhenClientPrincipalNullThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() -> new OAuth2ClientCredentialsAuthenticationToken(null))
+		assertThatThrownBy(() -> new OAuth2ClientCredentialsAuthenticationToken(null, this.scopes, this.additionalParameters))
 				.isInstanceOf(IllegalArgumentException.class)
 				.hasMessage("clientPrincipal cannot be null");
 	}
 
-	@Test
-	public void constructorWhenScopesNullThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() -> new OAuth2ClientCredentialsAuthenticationToken(this.clientPrincipal, null))
-				.isInstanceOf(IllegalArgumentException.class)
-				.hasMessage("scopes cannot be null");
-	}
-
 	@Test
 	public void constructorWhenClientPrincipalProvidedThenCreated() {
-		OAuth2ClientCredentialsAuthenticationToken authentication =
-				new OAuth2ClientCredentialsAuthenticationToken(this.clientPrincipal);
+		OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken(
+				this.clientPrincipal, this.scopes, this.additionalParameters);
 
 		assertThat(authentication.getGrantType()).isEqualTo(AuthorizationGrantType.CLIENT_CREDENTIALS);
 		assertThat(authentication.getPrincipal()).isEqualTo(this.clientPrincipal);
 		assertThat(authentication.getCredentials().toString()).isEmpty();
-		assertThat(authentication.getScopes()).isEmpty();
+		assertThat(authentication.getScopes()).isEqualTo(this.scopes);
+		assertThat(authentication.getAdditionalParameters()).isEqualTo(this.additionalParameters);
 	}
 
 	@Test
 	public void constructorWhenScopesProvidedThenCreated() {
 		Set<String> expectedScopes = Collections.singleton("test-scope");
 
-		OAuth2ClientCredentialsAuthenticationToken authentication =
-				new OAuth2ClientCredentialsAuthenticationToken(this.clientPrincipal, expectedScopes);
+		OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken(
+				this.clientPrincipal, expectedScopes, this.additionalParameters);
 
 		assertThat(authentication.getGrantType()).isEqualTo(AuthorizationGrantType.CLIENT_CREDENTIALS);
 		assertThat(authentication.getPrincipal()).isEqualTo(this.clientPrincipal);
 		assertThat(authentication.getCredentials().toString()).isEmpty();
 		assertThat(authentication.getScopes()).isEqualTo(expectedScopes);
+		assertThat(authentication.getAdditionalParameters()).isEqualTo(this.additionalParameters);
 	}
 }

+ 11 - 11
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java

@@ -124,7 +124,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
 
 		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
 		OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
-				authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal);
+				authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, null, null);
 
 		OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
 				(OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication);
@@ -169,7 +169,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
 
 		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
 		OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
-				authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal);
+				authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, null, null);
 
 		OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
 				(OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication);
@@ -199,7 +199,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
 		Set<String> requestedScopes = new HashSet<>(authorizedScopes);
 		requestedScopes.remove("scope1");
 		OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
-				authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, requestedScopes);
+				authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, requestedScopes, null);
 
 		OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
 				(OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication);
@@ -221,7 +221,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
 		Set<String> requestedScopes = new HashSet<>(authorizedScopes);
 		requestedScopes.add("unauthorized");
 		OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
-				authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, requestedScopes);
+				authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, requestedScopes, null);
 
 		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
 				.isInstanceOf(OAuth2AuthenticationException.class)
@@ -235,7 +235,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
 		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
 		OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
-				"invalid", clientPrincipal);
+				"invalid", clientPrincipal, null, null);
 
 		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
 				.isInstanceOf(OAuth2AuthenticationException.class)
@@ -250,7 +250,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
 		TestingAuthenticationToken clientPrincipal = new TestingAuthenticationToken(
 				registeredClient.getClientId(), registeredClient.getClientSecret());
 		OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
-				"refresh-token", clientPrincipal);
+				"refresh-token", clientPrincipal, null, null);
 
 		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
 				.isInstanceOf(OAuth2AuthenticationException.class)
@@ -265,7 +265,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
 		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(
 				registeredClient.getClientId(), registeredClient.getClientSecret(), ClientAuthenticationMethod.BASIC, null);
 		OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
-				"refresh-token", clientPrincipal);
+				"refresh-token", clientPrincipal, null, null);
 
 		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
 				.isInstanceOf(OAuth2AuthenticationException.class)
@@ -286,7 +286,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
 		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(
 				TestRegisteredClients.registeredClient2().build());
 		OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
-				authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal);
+				authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, null, null);
 
 		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
 				.isInstanceOf(OAuth2AuthenticationException.class)
@@ -308,7 +308,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
 
 		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
 		OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
-				authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal);
+				authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, null, null);
 
 		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
 				.isInstanceOf(OAuth2AuthenticationException.class)
@@ -331,7 +331,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
 
 		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
 		OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
-				authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal);
+				authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, null, null);
 
 		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
 				.isInstanceOf(OAuth2AuthenticationException.class)
@@ -355,7 +355,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
 
 		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
 		OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
-				authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal);
+				authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, null, null);
 
 		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
 				.isInstanceOf(OAuth2AuthenticationException.class)

+ 11 - 16
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationTokenTests.java

@@ -15,8 +15,8 @@
  */
 package org.springframework.security.oauth2.server.authorization.authentication;
 
-import java.util.Arrays;
-import java.util.HashSet;
+import java.util.Collections;
+import java.util.Map;
 import java.util.Set;
 
 import org.junit.Test;
@@ -34,42 +34,37 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy;
  * @since 0.0.3
  */
 public class OAuth2RefreshTokenAuthenticationTokenTests {
-	private final OAuth2ClientAuthenticationToken clientPrincipal =
+	private OAuth2ClientAuthenticationToken clientPrincipal =
 			new OAuth2ClientAuthenticationToken(TestRegisteredClients.registeredClient().build());
+	private Set<String> scopes = Collections.singleton("scope1");
+	private Map<String, Object> additionalParameters = Collections.singletonMap("param1", "value1");
 
 	@Test
 	public void constructorWhenRefreshTokenNullOrEmptyThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationToken(null, this.clientPrincipal))
+		assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationToken(null, this.clientPrincipal, this.scopes, this.additionalParameters))
 				.isInstanceOf(IllegalArgumentException.class)
 				.hasMessage("refreshToken cannot be empty");
-		assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationToken("", this.clientPrincipal))
+		assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationToken("", this.clientPrincipal, this.scopes, this.additionalParameters))
 				.isInstanceOf(IllegalArgumentException.class)
 				.hasMessage("refreshToken cannot be empty");
 	}
 
 	@Test
 	public void constructorWhenClientPrincipalNullThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationToken("refresh-token", null))
+		assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationToken("refresh-token", null, this.scopes, this.additionalParameters))
 				.isInstanceOf(IllegalArgumentException.class)
 				.hasMessage("clientPrincipal cannot be null");
 	}
 
-	@Test
-	public void constructorWhenScopesNullThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationToken("refresh-token", this.clientPrincipal, null))
-				.isInstanceOf(IllegalArgumentException.class)
-				.hasMessage("scopes cannot be null");
-	}
-
 	@Test
 	public void constructorWhenScopesProvidedThenCreated() {
-		Set<String> expectedScopes = new HashSet<>(Arrays.asList("scope-a", "scope-b"));
 		OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
-				"refresh-token", this.clientPrincipal, expectedScopes);
+				"refresh-token", this.clientPrincipal, this.scopes, this.additionalParameters);
 		assertThat(authentication.getGrantType()).isEqualTo(AuthorizationGrantType.REFRESH_TOKEN);
 		assertThat(authentication.getRefreshToken()).isEqualTo("refresh-token");
 		assertThat(authentication.getPrincipal()).isEqualTo(this.clientPrincipal);
 		assertThat(authentication.getCredentials().toString()).isEmpty();
-		assertThat(authentication.getScopes()).isEqualTo(expectedScopes);
+		assertThat(authentication.getScopes()).isEqualTo(this.scopes);
+		assertThat(authentication.getAdditionalParameters()).isEqualTo(this.additionalParameters);
 	}
 }

+ 22 - 11
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020 the original author or authors.
+ * Copyright 2020-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,10 +15,22 @@
  */
 package org.springframework.security.oauth2.server.authorization.web;
 
+import java.time.Duration;
+import java.time.Instant;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Map;
+
+import javax.servlet.FilterChain;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 import org.mockito.ArgumentCaptor;
+
 import org.springframework.http.HttpStatus;
 import org.springframework.http.converter.HttpMessageConverter;
 import org.springframework.mock.http.client.MockClientHttpResponse;
@@ -47,16 +59,6 @@ import org.springframework.security.oauth2.server.authorization.client.Registere
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
 import org.springframework.util.StringUtils;
 
-import javax.servlet.FilterChain;
-import javax.servlet.http.HttpServletRequest;
-import javax.servlet.http.HttpServletResponse;
-import java.time.Duration;
-import java.time.Instant;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.HashSet;
-import java.util.Map;
-
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.assertj.core.api.Assertions.entry;
@@ -232,6 +234,8 @@ public class OAuth2TokenEndpointFilterTests {
 		assertThat(authorizationCodeAuthentication.getPrincipal()).isEqualTo(clientPrincipal);
 		assertThat(authorizationCodeAuthentication.getRedirectUri()).isEqualTo(
 				request.getParameter(OAuth2ParameterNames.REDIRECT_URI));
+		assertThat(authorizationCodeAuthentication.getAdditionalParameters())
+				.containsExactly(entry("custom-param-1", "custom-value-1"));
 
 		assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value());
 		OAuth2AccessTokenResponse accessTokenResponse = readAccessTokenResponse(response);
@@ -292,6 +296,8 @@ public class OAuth2TokenEndpointFilterTests {
 				clientCredentialsAuthenticationCaptor.getValue();
 		assertThat(clientCredentialsAuthentication.getPrincipal()).isEqualTo(clientPrincipal);
 		assertThat(clientCredentialsAuthentication.getScopes()).isEqualTo(registeredClient.getScopes());
+		assertThat(clientCredentialsAuthentication.getAdditionalParameters())
+				.containsExactly(entry("custom-param-1", "custom-value-1"));
 
 		assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value());
 		OAuth2AccessTokenResponse accessTokenResponse = readAccessTokenResponse(response);
@@ -372,6 +378,8 @@ public class OAuth2TokenEndpointFilterTests {
 		assertThat(refreshTokenAuthenticationToken.getRefreshToken()).isEqualTo(refreshToken.getTokenValue());
 		assertThat(refreshTokenAuthenticationToken.getPrincipal()).isEqualTo(clientPrincipal);
 		assertThat(refreshTokenAuthenticationToken.getScopes()).isEqualTo(registeredClient.getScopes());
+		assertThat(refreshTokenAuthenticationToken.getAdditionalParameters())
+				.containsExactly(entry("custom-param-1", "custom-value-1"));
 
 		assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value());
 		OAuth2AccessTokenResponse accessTokenResponse = readAccessTokenResponse(response);
@@ -429,6 +437,7 @@ public class OAuth2TokenEndpointFilterTests {
 		request.addParameter(OAuth2ParameterNames.REDIRECT_URI, redirectUris[0]);
 		// The client does not need to send the client ID param, but we are resilient in case they do
 		request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId());
+		request.addParameter("custom-param-1", "custom-value-1");
 
 		return request;
 	}
@@ -441,6 +450,7 @@ public class OAuth2TokenEndpointFilterTests {
 		request.addParameter(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.CLIENT_CREDENTIALS.getValue());
 		request.addParameter(OAuth2ParameterNames.SCOPE,
 				StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " "));
+		request.addParameter("custom-param-1", "custom-value-1");
 
 		return request;
 	}
@@ -454,6 +464,7 @@ public class OAuth2TokenEndpointFilterTests {
 		request.addParameter(OAuth2ParameterNames.REFRESH_TOKEN, "refresh-token");
 		request.addParameter(OAuth2ParameterNames.SCOPE,
 				StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " "));
+		request.addParameter("custom-param-1", "custom-value-1");
 
 		return request;
 	}