Joe Grandja 2 сар өмнө
parent
commit
c4e8427a3a

+ 15 - 2
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationProvider.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020-2024 the original author or authors.
+ * Copyright 2020-2025 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -22,6 +22,7 @@ import java.util.Set;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 
+import org.springframework.core.log.LogMessage;
 import org.springframework.security.authentication.AnonymousAuthenticationToken;
 import org.springframework.security.authentication.AuthenticationProvider;
 import org.springframework.security.core.Authentication;
@@ -109,6 +110,19 @@ public final class OAuth2DeviceVerificationAuthenticationProvider implements Aut
 			this.logger.trace("Retrieved authorization with user code");
 		}
 
+		OAuth2Authorization.Token<OAuth2UserCode> userCode = authorization.getToken(OAuth2UserCode.class);
+		if (!userCode.isActive()) {
+			if (!userCode.isInvalidated()) {
+				authorization = OAuth2Authorization.from(authorization).invalidate(userCode.getToken()).build();
+				this.authorizationService.save(authorization);
+				if (this.logger.isWarnEnabled()) {
+					this.logger.warn(LogMessage.format("Invalidated user code used by registered client '%s'",
+							authorization.getRegisteredClientId()));
+				}
+			}
+			throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_GRANT);
+		}
+
 		Authentication principal = (Authentication) deviceVerificationAuthentication.getPrincipal();
 		if (!isPrincipalAuthenticated(principal)) {
 			if (this.logger.isTraceEnabled()) {
@@ -161,7 +175,6 @@ public final class OAuth2DeviceVerificationAuthenticationProvider implements Aut
 					requestedScopes, currentAuthorizedScopes);
 		}
 
-		OAuth2Authorization.Token<OAuth2UserCode> userCode = authorization.getToken(OAuth2UserCode.class);
 		// @formatter:off
 		authorization = OAuth2Authorization.from(authorization)
 				.principalName(principal.getName())

+ 84 - 2
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationProviderTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020-2023 the original author or authors.
+ * Copyright 2020-2025 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -20,6 +20,7 @@ import java.time.Instant;
 import java.time.temporal.ChronoUnit;
 import java.util.Collections;
 import java.util.Map;
+import java.util.function.Consumer;
 import java.util.function.Function;
 
 import org.junit.jupiter.api.BeforeEach;
@@ -55,6 +56,7 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
@@ -145,10 +147,81 @@ public class OAuth2DeviceVerificationAuthenticationProviderTests {
 		verifyNoInteractions(this.registeredClientRepository, this.authorizationConsentService);
 	}
 
+	@Test
+	public void authenticateWhenUserCodeIsInvalidatedThenThrowOAuth2AuthenticationException() {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+		// @formatter:off
+		OAuth2Authorization authorization = TestOAuth2Authorizations
+				.authorization(registeredClient)
+				.token(createDeviceCode())
+				.token(createUserCode(), withInvalidated())
+				.attribute(OAuth2ParameterNames.SCOPE, registeredClient.getScopes())
+				.build();
+		// @formatter:on
+		given(this.authorizationService.findByToken(eq(USER_CODE),
+				eq(OAuth2DeviceVerificationAuthenticationProvider.USER_CODE_TOKEN_TYPE)))
+			.willReturn(authorization);
+		Authentication authentication = createAuthentication();
+		// @formatter:off
+		assertThatExceptionOfType(OAuth2AuthenticationException.class)
+				.isThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+				.extracting(OAuth2AuthenticationException::getError)
+				.extracting(OAuth2Error::getErrorCode)
+				.isEqualTo(OAuth2ErrorCodes.INVALID_GRANT);
+		// @formatter:on
+
+		verify(this.authorizationService).findByToken(USER_CODE,
+				OAuth2DeviceVerificationAuthenticationProvider.USER_CODE_TOKEN_TYPE);
+		verifyNoMoreInteractions(this.authorizationService);
+		verifyNoInteractions(this.registeredClientRepository, this.authorizationConsentService);
+	}
+
+	@Test
+	public void authenticateWhenUserCodeIsExpiredAndNotInvalidatedThenThrowOAuth2AuthenticationException() {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+		// @formatter:off
+		OAuth2Authorization authorization = TestOAuth2Authorizations
+				.authorization(registeredClient)
+				// Device code would also be expired but not relevant for this test
+				.token(createDeviceCode())
+				.token(createExpiredUserCode())
+				.attribute(OAuth2ParameterNames.SCOPE, registeredClient.getScopes())
+				.build();
+		// @formatter:on
+		given(this.authorizationService.findByToken(eq(USER_CODE),
+				eq(OAuth2DeviceVerificationAuthenticationProvider.USER_CODE_TOKEN_TYPE)))
+			.willReturn(authorization);
+		Authentication authentication = createAuthentication();
+		// @formatter:off
+		assertThatExceptionOfType(OAuth2AuthenticationException.class)
+				.isThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+				.extracting(OAuth2AuthenticationException::getError)
+				.extracting(OAuth2Error::getErrorCode)
+				.isEqualTo(OAuth2ErrorCodes.INVALID_GRANT);
+		// @formatter:on
+
+		ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
+		verify(this.authorizationService).findByToken(USER_CODE,
+				OAuth2DeviceVerificationAuthenticationProvider.USER_CODE_TOKEN_TYPE);
+		verify(this.authorizationService).save(authorizationCaptor.capture());
+		verifyNoMoreInteractions(this.authorizationService);
+		verifyNoInteractions(this.registeredClientRepository, this.authorizationConsentService);
+
+		OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
+		assertThat(updatedAuthorization.getToken(OAuth2UserCode.class)).extracting(isInvalidated()).isEqualTo(true);
+	}
+
 	@Test
 	public void authenticateWhenPrincipalNotAuthenticatedThenReturnUnauthenticated() {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
-		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
+		// @formatter:off
+		OAuth2Authorization authorization = TestOAuth2Authorizations
+				.authorization(registeredClient)
+				.token(createDeviceCode())
+				.token(createUserCode())
+				.attribute(OAuth2ParameterNames.SCOPE, registeredClient.getScopes())
+				.build();
+		// @formatter:on
 		TestingAuthenticationToken principal = new TestingAuthenticationToken("user", null);
 		Authentication authentication = new OAuth2DeviceVerificationAuthenticationToken(principal, USER_CODE,
 				Collections.emptyMap());
@@ -331,6 +404,15 @@ public class OAuth2DeviceVerificationAuthenticationProviderTests {
 		return new OAuth2UserCode(USER_CODE, issuedAt, issuedAt.plus(30, ChronoUnit.MINUTES));
 	}
 
+	private static OAuth2UserCode createExpiredUserCode() {
+		Instant issuedAt = Instant.now().minus(45, ChronoUnit.MINUTES);
+		return new OAuth2UserCode(USER_CODE, issuedAt, issuedAt.plus(30, ChronoUnit.MINUTES));
+	}
+
+	private static Consumer<Map<String, Object>> withInvalidated() {
+		return (metadata) -> metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true);
+	}
+
 	private static Function<OAuth2Authorization.Token<? extends OAuth2Token>, Boolean> isInvalidated() {
 		return (token) -> token.getMetadata(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME);
 	}