Przeglądaj źródła

Customize OAuth2AuthorizationConsent prior to saving

Closes gh-436
Steve Riesenberg 4 lat temu
rodzic
commit
4ce999c014

+ 212 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationConsentContext.java

@@ -0,0 +1,212 @@
+/*
+ * Copyright 2020-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.function.Consumer;
+
+import org.springframework.lang.Nullable;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.oauth2.core.context.Context;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
+import org.springframework.util.Assert;
+import org.springframework.util.CollectionUtils;
+
+/**
+ * A context that holds an {@link OAuth2AuthorizationConsent.Builder} and (optionally) additional information
+ * and is used when customizing the building of {@link OAuth2AuthorizationConsent}.
+ *
+ * @author Steve Riesenberg
+ * @since 0.2.1
+ * @see Context
+ */
+public final class OAuth2AuthorizationConsentContext implements Context {
+	private final Map<Object, Object> context;
+
+	/**
+	 * Constructs an {@code OAuth2AuthorizationConsentContext} using the provided parameters.
+	 *
+	 * @param context a {@code Map} of additional context information
+	 */
+	private OAuth2AuthorizationConsentContext(@Nullable Map<Object, Object> context) {
+		this.context = new HashMap<>();
+		if (!CollectionUtils.isEmpty(context)) {
+			this.context.putAll(context);
+		}
+	}
+
+	/**
+	 * Returns the {@link OAuth2AuthorizationConsent.Builder authorization consent builder}.
+	 *
+	 * @return the {@link OAuth2AuthorizationConsent.Builder}
+	 */
+	public OAuth2AuthorizationConsent.Builder getAuthorizationConsentBuilder() {
+		return get(OAuth2AuthorizationConsent.Builder.class);
+	}
+
+	/**
+	 * Returns the {@link Authentication} representing the {@code Principal} resource owner (or client).
+	 *
+	 * @param <T> the type of the {@code Authentication}
+	 * @return the {@link Authentication} representing the {@code Principal} resource owner (or client)
+	 */
+	@Nullable
+	public <T extends Authentication> T getPrincipal() {
+		return get(Builder.PRINCIPAL_AUTHENTICATION_KEY);
+	}
+
+	/**
+	 * Returns the {@link RegisteredClient registered client}.
+	 *
+	 * @return the {@link RegisteredClient}, or {@code null} if not available
+	 */
+	@Nullable
+	public RegisteredClient getRegisteredClient() {
+		return get(RegisteredClient.class);
+	}
+
+	/**
+	 * Returns the {@link OAuth2Authorization authorization}.
+	 *
+	 * @return the {@link OAuth2Authorization}, or {@code null} if not available
+	 */
+	@Nullable
+	public OAuth2Authorization getAuthorization() {
+		return get(OAuth2Authorization.class);
+	}
+
+	/**
+	 * Returns the {@link OAuth2AuthorizationRequest authorization request}.
+	 *
+	 * @return the {@link OAuth2AuthorizationRequest}, or {@code null} if not available
+	 */
+	@Nullable
+	public OAuth2AuthorizationRequest getAuthorizationRequest() {
+		return get(OAuth2AuthorizationRequest.class);
+	}
+
+	@SuppressWarnings("unchecked")
+	@Override
+	public <V> V get(Object key) {
+		return (V) this.context.get(key);
+	}
+
+	@Override
+	public boolean hasKey(Object key) {
+		return this.context.containsKey(key);
+	}
+
+	/**
+	 * Constructs a new {@link Builder} with the provided {@link OAuth2AuthorizationConsent.Builder}.
+	 *
+	 * @param authorizationConsentBuilder the {@link OAuth2AuthorizationConsent.Builder} to initialize the builder
+	 * @return the {@link Builder}
+	 */
+	public static OAuth2AuthorizationConsentContext.Builder with(OAuth2AuthorizationConsent.Builder authorizationConsentBuilder) {
+		return new Builder(authorizationConsentBuilder);
+	}
+
+	/**
+	 * A builder for {@link OAuth2AuthorizationConsentContext}.
+	 */
+	public static final class Builder {
+		private static final String PRINCIPAL_AUTHENTICATION_KEY =
+				Authentication.class.getName().concat(".PRINCIPAL");
+		private final Map<Object, Object> context = new HashMap<>();
+
+		private Builder(OAuth2AuthorizationConsent.Builder authorizationConsentBuilder) {
+			Assert.notNull(authorizationConsentBuilder, "authorizationConsentBuilder cannot be null");
+			put(OAuth2AuthorizationConsent.Builder.class, authorizationConsentBuilder);
+		}
+
+		/**
+		 * Sets the {@link Authentication} representing the {@code Principal} resource owner (or client).
+		 *
+		 * @param principal the {@link Authentication} representing the {@code Principal} resource owner (or client)
+		 * @return the {@link Builder} for further configuration
+		 */
+		public Builder principal(Authentication principal) {
+			return put(PRINCIPAL_AUTHENTICATION_KEY, principal);
+		}
+
+		/**
+		 * Sets the {@link RegisteredClient registered client}.
+		 *
+		 * @param registeredClient the {@link RegisteredClient}
+		 * @return the {@link Builder} for further configuration
+		 */
+		public Builder registeredClient(RegisteredClient registeredClient) {
+			return put(RegisteredClient.class, registeredClient);
+		}
+
+		/**
+		 * Sets the {@link OAuth2Authorization authorization}.
+		 *
+		 * @param authorization the {@link OAuth2Authorization}
+		 * @return the {@link Builder} for further configuration
+		 */
+		public Builder authorization(OAuth2Authorization authorization) {
+			return put(OAuth2Authorization.class, authorization);
+		}
+
+		/**
+		 * Sets the {@link OAuth2AuthorizationRequest authorization request}.
+		 *
+		 * @param authorizationRequest the {@link OAuth2AuthorizationRequest}
+		 * @return the {@link Builder} for further configuration
+		 */
+		public Builder authorizationRequest(OAuth2AuthorizationRequest authorizationRequest) {
+			return put(OAuth2AuthorizationRequest.class, authorizationRequest);
+		}
+
+		/**
+		 * Associates an attribute.
+		 *
+		 * @param key the key for the attribute
+		 * @param value the value of the attribute
+		 * @return the {@link OAuth2TokenContext.AbstractBuilder} for further configuration
+		 */
+		public Builder put(Object key, Object value) {
+			Assert.notNull(key, "key cannot be null");
+			Assert.notNull(value, "value cannot be null");
+			this.context.put(key, value);
+			return this;
+		}
+
+		/**
+		 * A {@code Consumer} of the attributes {@code Map}
+		 * allowing the ability to add, replace, or remove.
+		 *
+		 * @param contextConsumer a {@link Consumer} of the attributes {@code Map}
+		 * @return the {@link Builder} for further configuration
+		 */
+		public Builder context(Consumer<Map<Object, Object>> contextConsumer) {
+			contextConsumer.accept(this.context);
+			return this;
+		}
+
+		/**
+		 * Builds a new {@link OAuth2AuthorizationConsentContext}.
+		 *
+		 * @return the {@link OAuth2AuthorizationConsentContext}
+		 */
+		public OAuth2AuthorizationConsentContext build() {
+			return new OAuth2AuthorizationConsentContext(this.context);
+		}
+	}
+}

+ 53 - 11
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProvider.java

@@ -29,6 +29,7 @@ import java.util.function.Supplier;
 
 import org.springframework.security.authentication.AnonymousAuthenticationToken;
 import org.springframework.security.authentication.AuthenticationProvider;
+import org.springframework.security.config.Customizer;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
@@ -46,6 +47,7 @@ import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
 import org.springframework.security.oauth2.core.oidc.OidcScopes;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsent;
+import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsentContext;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsentService;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
@@ -82,6 +84,7 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
 	private final OAuth2AuthorizationConsentService authorizationConsentService;
 	private Supplier<String> authorizationCodeGenerator = DEFAULT_AUTHORIZATION_CODE_GENERATOR::generateKey;
 	private Function<String, OAuth2AuthenticationValidator> authenticationValidatorResolver = DEFAULT_AUTHENTICATION_VALIDATOR_RESOLVER;
+	private Customizer<OAuth2AuthorizationConsentContext> authorizationConsentCustomizer;
 
 	/**
 	 * Constructs an {@code OAuth2AuthorizationCodeRequestAuthenticationProvider} using the provided parameters.
@@ -145,6 +148,30 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
 		this.authenticationValidatorResolver = authenticationValidatorResolver;
 	}
 
+	/**
+	 * Sets the {@link Customizer} providing access to the {@link OAuth2AuthorizationConsentContext} containing an
+	 * {@link OAuth2AuthorizationConsent.Builder}.
+	 *
+	 * <p>
+	 * The following context attributes are available:
+	 * <ul>
+	 * <li>The {@link OAuth2AuthorizationConsent.Builder} used to build the authorization consent
+	 * prior to {@link OAuth2AuthorizationConsentService#save(OAuth2AuthorizationConsent)}</li>
+	 * <li>The {@link Authentication authentication principal} of type
+	 * {@link OAuth2AuthorizationCodeRequestAuthenticationToken}</li>
+	 * <li>The {@link OAuth2Authorization} associated with the state token presented in the
+	 * authorization consent request.</li>
+	 * <li>The {@link OAuth2AuthorizationRequest} requiring the resource owner's consent.</li>
+	 * </ul>
+	 *
+	 * @param authorizationConsentCustomizer the {@link Customizer} providing access to the
+	 * {@link OAuth2AuthorizationConsentContext} containing an {@link OAuth2AuthorizationConsent.Builder}
+	 */
+	public void setAuthorizationConsentCustomizer(Customizer<OAuth2AuthorizationConsentContext> authorizationConsentCustomizer) {
+		Assert.notNull(authorizationConsentCustomizer, "authorizationConsentCustomizer cannot be null");
+		this.authorizationConsentCustomizer = authorizationConsentCustomizer;
+	}
+
 	private Authentication authenticateAuthorizationRequest(Authentication authentication) throws AuthenticationException {
 		OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication =
 				(OAuth2AuthorizationCodeRequestAuthenticationToken) authentication;
@@ -301,7 +328,8 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
 		Set<String> currentAuthorizedScopes = currentAuthorizationConsent != null ?
 				currentAuthorizationConsent.getScopes() : Collections.emptySet();
 
-		if (authorizedScopes.isEmpty() && currentAuthorizedScopes.isEmpty()) {
+		if (authorizedScopes.isEmpty() && currentAuthorizedScopes.isEmpty()
+				&& authorizationCodeRequestAuthentication.getAdditionalParameters().isEmpty()) {
 			// Authorization consent denied
 			this.authorizationService.remove(authorization);
 			throwError(OAuth2ErrorCodes.ACCESS_DENIED, OAuth2ParameterNames.CLIENT_ID,
@@ -321,16 +349,30 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
 			}
 		}
 
-		if (!authorizedScopes.isEmpty() && !authorizedScopes.equals(currentAuthorizedScopes)) {
-			OAuth2AuthorizationConsent.Builder authorizationConsentBuilder;
-			if (currentAuthorizationConsent != null) {
-				authorizationConsentBuilder = OAuth2AuthorizationConsent.from(currentAuthorizationConsent);
-			} else {
-				authorizationConsentBuilder = OAuth2AuthorizationConsent.withId(
-						authorization.getRegisteredClientId(), authorization.getPrincipalName());
-			}
-			authorizedScopes.forEach(authorizationConsentBuilder::scope);
-			OAuth2AuthorizationConsent authorizationConsent = authorizationConsentBuilder.build();
+		OAuth2AuthorizationConsent.Builder authorizationConsentBuilder;
+		if (currentAuthorizationConsent != null) {
+			authorizationConsentBuilder = OAuth2AuthorizationConsent.from(currentAuthorizationConsent);
+		} else {
+			authorizationConsentBuilder = OAuth2AuthorizationConsent.withId(
+					authorization.getRegisteredClientId(), authorization.getPrincipalName());
+		}
+		authorizedScopes.forEach(authorizationConsentBuilder::scope);
+
+		if (this.authorizationConsentCustomizer != null) {
+			// @formatter:off
+			OAuth2AuthorizationConsentContext authorizationConsentContext =
+					OAuth2AuthorizationConsentContext.with(authorizationConsentBuilder)
+							.principal(authorizationCodeRequestAuthentication)
+							.registeredClient(registeredClient)
+							.authorization(authorization)
+							.authorizationRequest(authorizationRequest)
+							.build();
+			// @formatter:on
+			this.authorizationConsentCustomizer.customize(authorizationConsentContext);
+		}
+
+		OAuth2AuthorizationConsent authorizationConsent = authorizationConsentBuilder.build();
+		if (!authorizationConsent.equals(currentAuthorizationConsent)) {
 			this.authorizationConsentService.save(authorizationConsent);
 		}
 

+ 156 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java

@@ -23,14 +23,17 @@ import java.security.Principal;
 import java.text.MessageFormat;
 import java.time.Instant;
 import java.time.temporal.ChronoUnit;
+import java.util.Arrays;
 import java.util.Base64;
 import java.util.HashSet;
 import java.util.List;
+import java.util.Map;
 import java.util.Set;
 
 import com.nimbusds.jose.jwk.JWKSet;
 import com.nimbusds.jose.jwk.source.JWKSource;
 import com.nimbusds.jose.proc.SecurityContext;
+import org.assertj.core.matcher.AssertionMatcher;
 import org.junit.After;
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
@@ -52,12 +55,14 @@ import org.springframework.mock.http.client.MockClientHttpResponse;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.security.authentication.AuthenticationProvider;
 import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.config.Customizer;
 import org.springframework.security.config.annotation.web.builders.HttpSecurity;
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
 import org.springframework.security.config.annotation.web.configuration.OAuth2AuthorizationServerConfiguration;
 import org.springframework.security.config.test.SpringTestRule;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.GrantedAuthority;
+import org.springframework.security.core.authority.SimpleGrantedAuthority;
 import org.springframework.security.crypto.password.NoOpPasswordEncoder;
 import org.springframework.security.crypto.password.PasswordEncoder;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
@@ -77,10 +82,13 @@ import org.springframework.security.oauth2.server.authorization.JdbcOAuth2Author
 import org.springframework.security.oauth2.server.authorization.JdbcOAuth2AuthorizationService;
 import org.springframework.security.oauth2.server.authorization.JwtEncodingContext;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
+import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsent;
+import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsentContext;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsentService;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
 import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer;
 import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations;
+import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationProvider;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationToken;
 import org.springframework.security.oauth2.server.authorization.client.JdbcRegisteredClientRepository;
 import org.springframework.security.oauth2.server.authorization.client.JdbcRegisteredClientRepository.RegisteredClientParametersMapper;
@@ -524,6 +532,60 @@ public class OAuth2AuthorizationCodeGrantTests {
 		assertThat(authorization).isNotNull();
 	}
 
+	@Test
+	public void requestWhenCustomConsentCustomizerConfiguredThenUsed() throws Exception {
+		this.spring.register(AuthorizationServerConfigurationCustomConsentRequest.class).autowire();
+
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
+				.clientSettings(ClientSettings.builder()
+						.requireAuthorizationConsent(true)
+						.setting("custom.allowed-authorities", "authority-1 authority-2")
+						.build())
+				.build();
+		this.registeredClientRepository.save(registeredClient);
+
+		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient)
+				.build();
+		this.authorizationService.save(authorization);
+
+		MvcResult mvcResult = this.mvc.perform(post(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
+				.param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId())
+				.param("authority", "authority-1 authority-2")
+				.param(OAuth2ParameterNames.STATE, "state")
+				.with(user("principal")))
+				.andExpect(status().is3xxRedirection())
+				.andReturn();
+
+		String redirectedUrl = mvcResult.getResponse().getRedirectedUrl();
+		assertThat(redirectedUrl).matches("https://example.com\\?code=.{15,}&state=state");
+
+		String authorizationCode = extractParameterFromRedirectUri(redirectedUrl, "code");
+		OAuth2Authorization authorizationCodeAuthorization = this.authorizationService.findByToken(authorizationCode, AUTHORIZATION_CODE_TOKEN_TYPE);
+
+		mvcResult = this.mvc.perform(post(DEFAULT_TOKEN_ENDPOINT_URI)
+				.params(getTokenRequestParameters(registeredClient, authorizationCodeAuthorization))
+				.header(HttpHeaders.AUTHORIZATION, getAuthorizationHeader(registeredClient)))
+				.andExpect(status().isOk())
+				.andExpect(header().string(HttpHeaders.CACHE_CONTROL, containsString("no-store")))
+				.andExpect(header().string(HttpHeaders.PRAGMA, containsString("no-cache")))
+				.andExpect(jsonPath("$.access_token").isNotEmpty())
+				.andExpect(jsonPath("$.access_token").value(new AssertionMatcher<String>() {
+					@Override
+					public void assertion(String accessToken) throws AssertionError {
+						Jwt jwt = jwtDecoder.decode(accessToken);
+						assertThat(jwt.getClaimAsStringList(AUTHORITIES_CLAIM))
+								.containsExactlyInAnyOrder("authority-1", "authority-2");
+					}
+				}))
+				.andExpect(jsonPath("$.token_type").isNotEmpty())
+				.andExpect(jsonPath("$.expires_in").isNotEmpty())
+				.andExpect(jsonPath("$.refresh_token").isNotEmpty())
+				.andExpect(jsonPath("$.scope").doesNotExist())
+				.andReturn();
+
+		String json = mvcResult.getResponse().getContentAsString();
+	}
+
 	@Test
 	public void requestWhenAuthorizationEndpointCustomizedThenUsed() throws Exception {
 		this.spring.register(AuthorizationServerConfigurationCustomAuthorizationEndpoint.class).autowire();
@@ -722,6 +784,100 @@ public class OAuth2AuthorizationCodeGrantTests {
 		// @formatter:on
 	}
 
+	@EnableWebSecurity
+	static class AuthorizationServerConfigurationCustomConsentRequest extends AuthorizationServerConfiguration {
+		@Autowired
+		private RegisteredClientRepository registeredClientRepository;
+
+		@Autowired
+		private OAuth2AuthorizationService authorizationService;
+
+		@Autowired
+		private OAuth2AuthorizationConsentService authorizationConsentService;
+
+		// @formatter:off
+		@Bean
+		public SecurityFilterChain authorizationServerSecurityFilterChain(HttpSecurity http) throws Exception {
+			OAuth2AuthorizationServerConfigurer<HttpSecurity> authorizationServerConfigurer =
+					new OAuth2AuthorizationServerConfigurer<>();
+			authorizationServerConfigurer
+					.authorizationEndpoint(authorizationEndpoint ->
+							authorizationEndpoint.authenticationProvider(createProvider()));
+			RequestMatcher endpointsMatcher = authorizationServerConfigurer.getEndpointsMatcher();
+
+			http
+					.requestMatcher(endpointsMatcher)
+					.authorizeRequests(authorizeRequests ->
+							authorizeRequests.anyRequest().authenticated()
+					)
+					.csrf(csrf -> csrf.ignoringRequestMatchers(endpointsMatcher))
+					.apply(authorizationServerConfigurer);
+			return http.build();
+		}
+		// @formatter:on
+
+		@Bean
+		@Override
+		OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer() {
+			return context -> {
+				if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(context.getAuthorizationGrantType()) &&
+						OAuth2TokenType.ACCESS_TOKEN.equals(context.getTokenType())) {
+					OAuth2AuthorizationConsent authorizationConsent = authorizationConsentService.findById(
+							context.getRegisteredClient().getId(), context.getPrincipal().getName());
+
+					Set<String> authorities = new HashSet<>();
+					for (GrantedAuthority authority : authorizationConsent.getAuthorities()) {
+						authorities.add(authority.getAuthority());
+					}
+					context.getClaims().claim(AUTHORITIES_CLAIM, authorities);
+				}
+			};
+		}
+
+		private AuthenticationProvider createProvider() {
+			OAuth2AuthorizationCodeRequestAuthenticationProvider authorizationCodeRequestAuthenticationProvider =
+					new OAuth2AuthorizationCodeRequestAuthenticationProvider(
+							this.registeredClientRepository,
+							this.authorizationService,
+							this.authorizationConsentService);
+			authorizationCodeRequestAuthenticationProvider.setAuthorizationConsentCustomizer(new ConsentCustomizer());
+
+			return authorizationCodeRequestAuthenticationProvider;
+		}
+
+		static class ConsentCustomizer implements Customizer<OAuth2AuthorizationConsentContext> {
+			@Override
+			public void customize(OAuth2AuthorizationConsentContext authorizationConsentContext) {
+				OAuth2AuthorizationConsent.Builder authorizationConsentBuilder =
+						authorizationConsentContext.getAuthorizationConsentBuilder();
+				OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication =
+						authorizationConsentContext.getPrincipal();
+				Map<String, Object> additionalParameters =
+						authorizationCodeRequestAuthentication.getAdditionalParameters();
+				RegisteredClient registeredClient = authorizationConsentContext.getRegisteredClient();
+				ClientSettings clientSettings = registeredClient.getClientSettings();
+
+				Set<String> requestedAuthorities = authorities((String) additionalParameters.get("authority"));
+				Set<String> allowedAuthorities = authorities(clientSettings.getSetting("custom.allowed-authorities"));
+				for (String requestedAuthority : requestedAuthorities) {
+					if (allowedAuthorities.contains(requestedAuthority)) {
+						authorizationConsentBuilder.authority(new SimpleGrantedAuthority(requestedAuthority));
+					}
+				}
+			}
+
+			private static Set<String> authorities(String param) {
+				Set<String> authorities = new HashSet<>();
+				if (param != null) {
+					List<String> authorityValues = Arrays.asList(param.split(" "));
+					authorities.addAll(authorityValues);
+				}
+
+				return authorities;
+			}
+		}
+	}
+
 	@EnableWebSecurity
 	static class AuthorizationServerConfigurationCustomAuthorizationEndpoint extends AuthorizationServerConfiguration {
 		// @formatter:off

+ 91 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationConsentContextTests.java

@@ -0,0 +1,91 @@
+/*
+ * Copyright 2020-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization;
+
+import org.junit.Test;
+
+import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationToken;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
+import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/**
+ * Tests for {@link OAuth2AuthorizationConsentContext}.
+ *
+ * @author Steve Riesenberg
+ */
+public class OAuth2AuthorizationConsentContextTests {
+
+	@Test
+	public void withWhenAuthorizationConsentBuilderNullThenIllegalArgumentException() {
+		assertThatThrownBy(() -> OAuth2AuthorizationConsentContext.with(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("authorizationConsentBuilder cannot be null");
+	}
+
+	@Test
+	public void setWhenValueNullThenThrowIllegalArgumentException() {
+		OAuth2AuthorizationConsentContext.Builder builder = OAuth2AuthorizationConsentContext
+				.with(OAuth2AuthorizationConsent.withId("some-client", "some-principal"));
+		assertThatThrownBy(() -> builder.principal(null))
+				.isInstanceOf(IllegalArgumentException.class);
+		assertThatThrownBy(() -> builder.registeredClient(null))
+				.isInstanceOf(IllegalArgumentException.class);
+		assertThatThrownBy(() -> builder.authorization(null))
+				.isInstanceOf(IllegalArgumentException.class);
+		assertThatThrownBy(() -> builder.authorizationRequest(null))
+				.isInstanceOf(IllegalArgumentException.class);
+		assertThatThrownBy(() -> builder.put(null, ""))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void buildWhenAllValuesProvidedThenAllValuesAreSet() {
+		OAuth2AuthorizationConsent.Builder authorizationConsentBuilder = OAuth2AuthorizationConsent
+				.withId("some-client", "some-principal");
+		TestingAuthenticationToken principal = new TestingAuthenticationToken("principal", "password");
+		OAuth2AuthorizationCodeRequestAuthenticationToken authentication =
+				OAuth2AuthorizationCodeRequestAuthenticationToken.with("test-client", principal)
+						.authorizationUri("https://provider.com/oauth2/authorize")
+						.build();
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build();
+		OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
+				OAuth2AuthorizationRequest.class.getName());
+
+		OAuth2AuthorizationConsentContext context = OAuth2AuthorizationConsentContext
+				.with(authorizationConsentBuilder)
+				.principal(authentication)
+				.registeredClient(registeredClient)
+				.authorization(authorization)
+				.authorizationRequest(authorizationRequest)
+				.put("custom-key-1", "custom-value-1")
+				.context(ctx -> ctx.put("custom-key-2", "custom-value-2"))
+				.build();
+
+		assertThat(context.getAuthorizationConsentBuilder()).isEqualTo(authorizationConsentBuilder);
+		assertThat(context.<OAuth2AuthorizationCodeRequestAuthenticationToken>getPrincipal()).isEqualTo(authentication);
+		assertThat(context.getRegisteredClient()).isEqualTo(registeredClient);
+		assertThat(context.getAuthorization()).isEqualTo(authorization);
+		assertThat(context.getAuthorizationRequest()).isEqualTo(authorizationRequest);
+		assertThat(context.<String>get("custom-key-1")).isEqualTo("custom-value-1");
+		assertThat(context.<String>get("custom-key-2")).isEqualTo("custom-value-2");
+	}
+}

+ 57 - 1
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProviderTests.java

@@ -29,8 +29,10 @@ import org.junit.Test;
 import org.mockito.ArgumentCaptor;
 
 import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.config.Customizer;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.OAuth2AuthorizationCode;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.OAuth2TokenType;
@@ -41,8 +43,8 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
 import org.springframework.security.oauth2.core.oidc.OidcScopes;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
-import org.springframework.security.oauth2.core.OAuth2AuthorizationCode;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsent;
+import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsentContext;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsentService;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
 import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations;
@@ -129,6 +131,13 @@ public class OAuth2AuthorizationCodeRequestAuthenticationProviderTests {
 				.hasMessage("authenticationValidatorResolver cannot be null");
 	}
 
+	@Test
+	public void setAuthorizationConsentCustomizerWhenNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.authenticationProvider.setAuthorizationConsentCustomizer(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("authorizationConsentCustomizer cannot be null");
+	}
+
 	@Test
 	public void authenticateWhenInvalidClientIdThenThrowOAuth2AuthorizationCodeRequestAuthenticationException() {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
@@ -773,6 +782,53 @@ public class OAuth2AuthorizationCodeRequestAuthenticationProviderTests {
 		OAuth2AuthorizationCodeRequestAuthenticationToken authenticationResult =
 				(OAuth2AuthorizationCodeRequestAuthenticationToken) this.authenticationProvider.authenticate(authentication);
 
+		assertAuthorizationConsentRequestWithAuthorizationCodeResult(registeredClient, authorization, authenticationResult);
+	}
+
+	@Test
+	public void authenticateWhenCustomAuthorizationConsentCustomizerThenUsed() {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
+				.build();
+		when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
+				.thenReturn(registeredClient);
+		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient)
+				.principalName(this.principal.getName())
+				.build();
+		OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationRequest.class.getName());
+		Set<String> authorizedScopes = authorizationRequest.getScopes();
+		OAuth2AuthorizationCodeRequestAuthenticationToken authentication =
+				authorizationConsentRequestAuthentication(registeredClient, this.principal)
+						.scopes(authorizedScopes)		// Approve all scopes
+						.build();
+		when(this.authorizationService.findByToken(eq(authentication.getState()), eq(STATE_TOKEN_TYPE)))
+				.thenReturn(authorization);
+
+		@SuppressWarnings("unchecked")
+		Customizer<OAuth2AuthorizationConsentContext> authorizationConsentCustomizer = mock(Customizer.class);
+		this.authenticationProvider.setAuthorizationConsentCustomizer(authorizationConsentCustomizer);
+
+		OAuth2AuthorizationCodeRequestAuthenticationToken authenticationResult =
+				(OAuth2AuthorizationCodeRequestAuthenticationToken) this.authenticationProvider.authenticate(authentication);
+
+		assertAuthorizationConsentRequestWithAuthorizationCodeResult(registeredClient, authorization, authenticationResult);
+
+		ArgumentCaptor<OAuth2AuthorizationConsentContext> contextCaptor = ArgumentCaptor.forClass(OAuth2AuthorizationConsentContext.class);
+		verify(authorizationConsentCustomizer).customize(contextCaptor.capture());
+
+		OAuth2AuthorizationConsentContext context = contextCaptor.getValue();
+		assertThat((Authentication) context.getPrincipal()).isEqualTo(authentication);
+		assertThat(context.get(OAuth2AuthorizationConsent.Builder.class)).isInstanceOf(OAuth2AuthorizationConsent.Builder.class);
+		assertThat(context.get(OAuth2Authorization.class)).isInstanceOf(OAuth2Authorization.class);
+		assertThat(context.get(OAuth2AuthorizationRequest.class)).isInstanceOf(OAuth2AuthorizationRequest.class);
+	}
+
+	private void assertAuthorizationConsentRequestWithAuthorizationCodeResult(
+			RegisteredClient registeredClient,
+			OAuth2Authorization authorization,
+			OAuth2AuthorizationCodeRequestAuthenticationToken authenticationResult) {
+		OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationRequest.class.getName());
+		Set<String> authorizedScopes = authorizationRequest.getScopes();
+
 		ArgumentCaptor<OAuth2AuthorizationConsent> authorizationConsentCaptor = ArgumentCaptor.forClass(OAuth2AuthorizationConsent.class);
 		verify(this.authorizationConsentService).save(authorizationConsentCaptor.capture());
 		OAuth2AuthorizationConsent authorizationConsent = authorizationConsentCaptor.getValue();