Explorar el Código

Add Predicate for authorizationConsentRequired for device code grant

Introduces customizable Predicate to determine if user consent is
required in device authorization flows. Previously, device consent
handling used fixed logic. Now applications can define custom logic
for skipping or displaying consent pages.

Adds OAuth2DeviceVerificationAuthenticationContext and updates
OAuth2DeviceVerificationAuthenticationProvider with
setAuthorizationConsentRequired method.

Fixes gh-18016

Signed-off-by: Dinesh Gupta <dineshgupta630@outlook.com>
Joe Grandja hace 6 días
padre
commit
baa3b287d6

+ 156 - 0
oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationContext.java

@@ -0,0 +1,156 @@
+/*
+ * Copyright 2004-present the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.server.authorization.authentication;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+
+import org.springframework.lang.Nullable;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
+import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsent;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
+import org.springframework.util.Assert;
+
+/**
+ * An {@link OAuth2AuthenticationContext} that holds an
+ * {@link OAuth2DeviceVerificationAuthenticationToken} and additional information and is
+ * used when determining if authorization consent is required.
+ *
+ * @author Dinesh Gupta
+ * @since 7.0
+ * @see OAuth2AuthenticationContext
+ * @see OAuth2DeviceVerificationAuthenticationToken
+ * @see OAuth2DeviceVerificationAuthenticationProvider#setAuthorizationConsentRequired(java.util.function.Predicate)
+ */
+public final class OAuth2DeviceVerificationAuthenticationContext implements OAuth2AuthenticationContext {
+
+	private final Map<Object, Object> context;
+
+	private OAuth2DeviceVerificationAuthenticationContext(Map<Object, Object> context) {
+		this.context = Collections.unmodifiableMap(new HashMap<>(context));
+	}
+
+	@SuppressWarnings("unchecked")
+	@Nullable
+	@Override
+	public <V> V get(Object key) {
+		return hasKey(key) ? (V) this.context.get(key) : null;
+	}
+
+	@Override
+	public boolean hasKey(Object key) {
+		Assert.notNull(key, "key cannot be null");
+		return this.context.containsKey(key);
+	}
+
+	/**
+	 * Returns the {@link RegisteredClient registered client}.
+	 * @return the {@link RegisteredClient}
+	 */
+	public RegisteredClient getRegisteredClient() {
+		return get(RegisteredClient.class);
+	}
+
+	/**
+	 * Returns the {@link OAuth2Authorization authorization}.
+	 * @return the {@link OAuth2Authorization}
+	 */
+	public OAuth2Authorization getAuthorization() {
+		return get(OAuth2Authorization.class);
+	}
+
+	/**
+	 * Returns the {@link OAuth2AuthorizationConsent authorization consent}.
+	 * @return the {@link OAuth2AuthorizationConsent}, or {@code null} if not available
+	 */
+	@Nullable
+	public OAuth2AuthorizationConsent getAuthorizationConsent() {
+		return get(OAuth2AuthorizationConsent.class);
+	}
+
+	/**
+	 * Returns the requested scopes.
+	 * @return the requested scopes
+	 */
+	public Set<String> getRequestedScopes() {
+		Set<String> requestedScopes = getAuthorization().getAttribute(OAuth2ParameterNames.SCOPE);
+		return (requestedScopes != null) ? requestedScopes : Collections.emptySet();
+	}
+
+	/**
+	 * Constructs a new {@link Builder} with the provided
+	 * {@link OAuth2DeviceVerificationAuthenticationToken}.
+	 * @param authentication the {@link OAuth2DeviceVerificationAuthenticationToken}
+	 * @return the {@link Builder}
+	 */
+	public static Builder with(OAuth2DeviceVerificationAuthenticationToken authentication) {
+		return new Builder(authentication);
+	}
+
+	/**
+	 * A builder for {@link OAuth2DeviceVerificationAuthenticationContext}.
+	 */
+	public static final class Builder extends AbstractBuilder<OAuth2DeviceVerificationAuthenticationContext, Builder> {
+
+		private Builder(OAuth2DeviceVerificationAuthenticationToken authentication) {
+			super(authentication);
+		}
+
+		/**
+		 * Sets the {@link RegisteredClient registered client}.
+		 * @param registeredClient the {@link RegisteredClient}
+		 * @return the {@link Builder} for further configuration
+		 */
+		public Builder registeredClient(RegisteredClient registeredClient) {
+			return put(RegisteredClient.class, registeredClient);
+		}
+
+		/**
+		 * Sets the {@link OAuth2Authorization authorization}.
+		 * @param authorization the {@link OAuth2Authorization}
+		 * @return the {@link Builder} for further configuration
+		 */
+		public Builder authorization(OAuth2Authorization authorization) {
+			return put(OAuth2Authorization.class, authorization);
+		}
+
+		/**
+		 * Sets the {@link OAuth2AuthorizationConsent authorization consent}.
+		 * @param authorizationConsent the {@link OAuth2AuthorizationConsent}
+		 * @return the {@link Builder} for further configuration
+		 */
+		public Builder authorizationConsent(OAuth2AuthorizationConsent authorizationConsent) {
+			return put(OAuth2AuthorizationConsent.class, authorizationConsent);
+		}
+
+		/**
+		 * Builds a new {@link OAuth2DeviceVerificationAuthenticationContext}.
+		 * @return the {@link OAuth2DeviceVerificationAuthenticationContext}
+		 */
+		@Override
+		public OAuth2DeviceVerificationAuthenticationContext build() {
+			Assert.notNull(get(RegisteredClient.class), "registeredClient cannot be null");
+			Assert.notNull(get(OAuth2Authorization.class), "authorization cannot be null");
+			return new OAuth2DeviceVerificationAuthenticationContext(getContext());
+		}
+
+	}
+
+}

+ 42 - 4
oauth2/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationProvider.java

@@ -19,6 +19,7 @@ package org.springframework.security.oauth2.server.authorization.authentication;
 import java.security.Principal;
 import java.util.Base64;
 import java.util.Set;
+import java.util.function.Predicate;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
@@ -79,6 +80,8 @@ public final class OAuth2DeviceVerificationAuthenticationProvider implements Aut
 
 	private final OAuth2AuthorizationConsentService authorizationConsentService;
 
+	private Predicate<OAuth2DeviceVerificationAuthenticationContext> authorizationConsentRequired = OAuth2DeviceVerificationAuthenticationProvider::isAuthorizationConsentRequired;
+
 	/**
 	 * Constructs an {@code OAuth2DeviceVerificationAuthenticationProvider} using the
 	 * provided parameters.
@@ -143,10 +146,18 @@ public final class OAuth2DeviceVerificationAuthenticationProvider implements Aut
 
 		Set<String> requestedScopes = authorization.getAttribute(OAuth2ParameterNames.SCOPE);
 
+		OAuth2DeviceVerificationAuthenticationContext.Builder authenticationContextBuilder = OAuth2DeviceVerificationAuthenticationContext
+			.with(deviceVerificationAuthentication)
+			.registeredClient(registeredClient)
+			.authorization(authorization);
+
 		OAuth2AuthorizationConsent currentAuthorizationConsent = this.authorizationConsentService
 			.findById(registeredClient.getId(), principal.getName());
+		if (currentAuthorizationConsent != null) {
+			authenticationContextBuilder.authorizationConsent(currentAuthorizationConsent);
+		}
 
-		if (requiresAuthorizationConsent(requestedScopes, currentAuthorizationConsent)) {
+		if (this.authorizationConsentRequired.test(authenticationContextBuilder.build())) {
 			String state = DEFAULT_STATE_GENERATOR.generateKey();
 			authorization = OAuth2Authorization.from(authorization)
 				.principalName(principal.getName())
@@ -204,10 +215,37 @@ public final class OAuth2DeviceVerificationAuthenticationProvider implements Aut
 		return OAuth2DeviceVerificationAuthenticationToken.class.isAssignableFrom(authentication);
 	}
 
-	private static boolean requiresAuthorizationConsent(Set<String> requestedScopes,
-			OAuth2AuthorizationConsent authorizationConsent) {
+	/**
+	 * Sets the {@code Predicate} used to determine if authorization consent is required.
+	 *
+	 * <p>
+	 * The {@link OAuth2DeviceVerificationAuthenticationContext} gives the predicate
+	 * access to the {@link OAuth2DeviceVerificationAuthenticationToken}, as well as, the
+	 * following context attributes:
+	 * <ul>
+	 * <li>The {@link RegisteredClient} associated with the device authorization
+	 * request.</li>
+	 * <li>The {@link OAuth2Authorization} containing the device authorization request
+	 * parameters.</li>
+	 * <li>The {@link OAuth2AuthorizationConsent} previously granted to the
+	 * {@link RegisteredClient}, or {@code null} if not available.</li>
+	 * </ul>
+	 * </p>
+	 * @param authorizationConsentRequired the {@code Predicate} used to determine if
+	 * authorization consent is required
+	 */
+	public void setAuthorizationConsentRequired(
+			Predicate<OAuth2DeviceVerificationAuthenticationContext> authorizationConsentRequired) {
+		Assert.notNull(authorizationConsentRequired, "authorizationConsentRequired cannot be null");
+		this.authorizationConsentRequired = authorizationConsentRequired;
+	}
+
+	private static boolean isAuthorizationConsentRequired(
+			OAuth2DeviceVerificationAuthenticationContext authenticationContext) {
 
-		if (authorizationConsent != null && authorizationConsent.getScopes().containsAll(requestedScopes)) {
+		if (authenticationContext.getAuthorizationConsent() != null && authenticationContext.getAuthorizationConsent()
+			.getScopes()
+			.containsAll(authenticationContext.getRequestedScopes())) {
 			return false;
 		}
 

+ 33 - 0
oauth2/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationProviderTests.java

@@ -23,6 +23,7 @@ import java.util.Collections;
 import java.util.Map;
 import java.util.function.Consumer;
 import java.util.function.Function;
+import java.util.function.Predicate;
 
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
@@ -125,6 +126,13 @@ public class OAuth2DeviceVerificationAuthenticationProviderTests {
 		// @formatter:on
 	}
 
+	@Test
+	public void setAuthorizationConsentRequiredWhenNullThenThrowIllegalArgumentException() {
+		assertThatIllegalArgumentException()
+			.isThrownBy(() -> this.authenticationProvider.setAuthorizationConsentRequired(null))
+			.withMessage("authorizationConsentRequired cannot be null");
+	}
+
 	@Test
 	public void supportsWhenTypeOAuth2DeviceVerificationAuthenticationTokenThenReturnTrue() {
 		assertThat(this.authenticationProvider.supports(OAuth2DeviceVerificationAuthenticationToken.class)).isTrue();
@@ -382,6 +390,31 @@ public class OAuth2DeviceVerificationAuthenticationProviderTests {
 			.isEqualTo(authenticationResult.getState());
 	}
 
+	@Test
+	public void authenticateWhenCustomAuthorizationConsentRequiredThenUsed() {
+		@SuppressWarnings("unchecked")
+		Predicate<OAuth2DeviceVerificationAuthenticationContext> authorizationConsentRequired = mock(Predicate.class);
+		this.authenticationProvider.setAuthorizationConsentRequired(authorizationConsentRequired);
+
+		// @formatter:off
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient)
+				.authorizationGrantType(AuthorizationGrantType.DEVICE_CODE)
+				.token(createDeviceCode())
+				.token(createUserCode())
+				.attributes(Map::clear)
+				.attribute(OAuth2ParameterNames.SCOPE, registeredClient.getScopes())
+				.build();
+		// @formatter:on
+		Authentication authentication = createAuthentication();
+		given(this.registeredClientRepository.findById(anyString())).willReturn(registeredClient);
+		given(this.authorizationService.findByToken(anyString(), any(OAuth2TokenType.class))).willReturn(authorization);
+
+		this.authenticationProvider.authenticate(authentication);
+
+		verify(authorizationConsentRequired).test(any());
+	}
+
 	private static void mockAuthorizationServerContext() {
 		AuthorizationServerSettings authorizationServerSettings = AuthorizationServerSettings.builder().build();
 		TestAuthorizationServerContext authorizationServerContext = new TestAuthorizationServerContext(