Explorar o código

Customize when user info is called

Closes gh-13259
Steve Riesenberg hai 1 ano
pai
achega
96e3e4f8b1

+ 26 - 2
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -22,6 +22,7 @@ import java.util.HashSet;
 import java.util.Map;
 import java.util.Set;
 import java.util.function.Function;
+import java.util.function.Predicate;
 
 import reactor.core.publisher.Mono;
 
@@ -33,6 +34,7 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
 import org.springframework.security.oauth2.client.userinfo.DefaultReactiveOAuth2UserService;
 import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
 import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
@@ -71,6 +73,8 @@ public class OidcReactiveOAuth2UserService implements ReactiveOAuth2UserService<
 	private Function<ClientRegistration, Converter<Map<String, Object>, Map<String, Object>>> claimTypeConverterFactory = (
 			clientRegistration) -> DEFAULT_CLAIM_TYPE_CONVERTER;
 
+	private Predicate<OidcUserRequest> retrieveUserInfo = OidcUserRequestUtils::shouldRetrieveUserInfo;
+
 	/**
 	 * Returns the default {@link Converter}'s used for type conversion of claim values
 	 * for an {@link OidcUserInfo}.
@@ -123,7 +127,7 @@ public class OidcReactiveOAuth2UserService implements ReactiveOAuth2UserService<
 	}
 
 	private Mono<OidcUserInfo> getUserInfo(OidcUserRequest userRequest) {
-		if (!OidcUserRequestUtils.shouldRetrieveUserInfo(userRequest)) {
+		if (!this.retrieveUserInfo.test(userRequest)) {
 			return Mono.empty();
 		}
 		// @formatter:off
@@ -169,4 +173,24 @@ public class OidcReactiveOAuth2UserService implements ReactiveOAuth2UserService<
 		this.claimTypeConverterFactory = claimTypeConverterFactory;
 	}
 
+	/**
+	 * Sets the {@code Predicate} used to determine if the UserInfo Endpoint should be
+	 * called to retrieve information about the End-User (Resource Owner).
+	 * <p>
+	 * By default, the UserInfo Endpoint is called if all of the following are true:
+	 * <ul>
+	 * <li>The user info endpoint is defined on the ClientRegistration</li>
+	 * <li>The Client Registration uses the
+	 * {@link AuthorizationGrantType#AUTHORIZATION_CODE} and scopes in the access token
+	 * are defined in the {@link ClientRegistration}</li>
+	 * </ul>
+	 * @param retrieveUserInfo the function used to determine if the UserInfo Endpoint
+	 * should be called
+	 * @since 6.3
+	 */
+	public final void setRetrieveUserInfo(Predicate<OidcUserRequest> retrieveUserInfo) {
+		Assert.notNull(retrieveUserInfo, "retrieveUserInfo cannot be null");
+		this.retrieveUserInfo = retrieveUserInfo;
+	}
+
 }

+ 30 - 2
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -24,6 +24,7 @@ import java.util.LinkedHashSet;
 import java.util.Map;
 import java.util.Set;
 import java.util.function.Function;
+import java.util.function.Predicate;
 
 import org.springframework.core.convert.TypeDescriptor;
 import org.springframework.core.convert.converter.Converter;
@@ -78,6 +79,8 @@ public class OidcUserService implements OAuth2UserService<OidcUserRequest, OidcU
 	private Function<ClientRegistration, Converter<Map<String, Object>, Map<String, Object>>> claimTypeConverterFactory = (
 			clientRegistration) -> DEFAULT_CLAIM_TYPE_CONVERTER;
 
+	private Predicate<OidcUserRequest> retrieveUserInfo = this::shouldRetrieveUserInfo;
+
 	/**
 	 * Returns the default {@link Converter}'s used for type conversion of claim values
 	 * for an {@link OidcUserInfo}.
@@ -105,7 +108,7 @@ public class OidcUserService implements OAuth2UserService<OidcUserRequest, OidcU
 	public OidcUser loadUser(OidcUserRequest userRequest) throws OAuth2AuthenticationException {
 		Assert.notNull(userRequest, "userRequest cannot be null");
 		OidcUserInfo userInfo = null;
-		if (this.shouldRetrieveUserInfo(userRequest)) {
+		if (this.retrieveUserInfo.test(userRequest)) {
 			OAuth2User oauth2User = this.oauth2UserService.loadUser(userRequest);
 			Map<String, Object> claims = getClaims(userRequest, oauth2User);
 			userInfo = new OidcUserInfo(claims);
@@ -221,10 +224,35 @@ public class OidcUserService implements OAuth2UserService<OidcUserRequest, OidcU
 	 * resource will be requested, otherwise it will not.
 	 * @param accessibleScopes the scope(s) that allow access to the user info resource
 	 * @since 5.2
+	 * @deprecated Use {@link #setRetrieveUserInfo(Predicate)} instead
 	 */
+	@Deprecated(since = "6.3", forRemoval = true)
 	public final void setAccessibleScopes(Set<String> accessibleScopes) {
 		Assert.notNull(accessibleScopes, "accessibleScopes cannot be null");
 		this.accessibleScopes = accessibleScopes;
 	}
 
+	/**
+	 * Sets the {@code Predicate} used to determine if the UserInfo Endpoint should be
+	 * called to retrieve information about the End-User (Resource Owner).
+	 * <p>
+	 * By default, the UserInfo Endpoint is called if all of the following are true:
+	 * <ul>
+	 * <li>The user info endpoint is defined on the ClientRegistration</li>
+	 * <li>The Client Registration uses the
+	 * {@link AuthorizationGrantType#AUTHORIZATION_CODE}</li>
+	 * <li>The access token contains one or more scopes allowed to access the UserInfo
+	 * Endpoint ({@link OidcScopes#PROFILE profile}, {@link OidcScopes#EMAIL email},
+	 * {@link OidcScopes#ADDRESS address} or {@link OidcScopes#PHONE phone}) or the access
+	 * token scopes are empty</li>
+	 * </ul>
+	 * @param retrieveUserInfo the function used to determine if the UserInfo Endpoint
+	 * should be called
+	 * @since 6.3
+	 */
+	public final void setRetrieveUserInfo(Predicate<OidcUserRequest> retrieveUserInfo) {
+		Assert.notNull(retrieveUserInfo, "retrieveUserInfo cannot be null");
+		this.retrieveUserInfo = retrieveUserInfo;
+	}
+
 }

+ 52 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java

@@ -24,6 +24,7 @@ import java.util.HashMap;
 import java.util.Iterator;
 import java.util.Map;
 import java.util.function.Function;
+import java.util.function.Predicate;
 
 import okhttp3.mockwebserver.MockResponse;
 import okhttp3.mockwebserver.MockWebServer;
@@ -107,6 +108,15 @@ public class OidcReactiveOAuth2UserServiceTests {
 		assertThatIllegalArgumentException().isThrownBy(() -> this.userService.setClaimTypeConverterFactory(null));
 	}
 
+	@Test
+	public void setRetrieveUserInfoWhenNullThenThrowIllegalArgumentException() {
+		// @formatter:off
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> this.userService.setRetrieveUserInfo(null))
+				.withMessage("retrieveUserInfo cannot be null");
+		// @formatter:on
+	}
+
 	@Test
 	public void loadUserWhenUserInfoUriNullThenUserInfoNotRetrieved() {
 		this.registration.userInfoUri(null);
@@ -183,6 +193,48 @@ public class OidcReactiveOAuth2UserServiceTests {
 		verify(customClaimTypeConverterFactory).apply(same(userRequest.getClientRegistration()));
 	}
 
+	@Test
+	public void loadUserWhenTokenScopesIsEmptyThenUserInfoNotRetrieved() {
+		// @formatter:off
+		OAuth2AccessToken accessToken = new OAuth2AccessToken(
+				this.accessToken.getTokenType(),
+				this.accessToken.getTokenValue(),
+				this.accessToken.getIssuedAt(),
+				this.accessToken.getExpiresAt(),
+				Collections.emptySet());
+		// @formatter:on
+		OidcUserRequest userRequest = new OidcUserRequest(this.registration.build(), accessToken, this.idToken);
+		OidcUser oidcUser = this.userService.loadUser(userRequest).block();
+		assertThat(oidcUser).isNotNull();
+		assertThat(oidcUser.getUserInfo()).isNull();
+	}
+
+	@Test
+	public void loadUserWhenCustomRetrieveUserInfoSetThenUsed() {
+		Map<String, Object> attributes = new HashMap<>();
+		attributes.put(StandardClaimNames.SUB, "subject");
+		attributes.put("user", "steve");
+		OAuth2User oauth2User = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), attributes,
+				"user");
+		given(this.oauth2UserService.loadUser(any())).willReturn(Mono.just(oauth2User));
+		Predicate<OidcUserRequest> customRetrieveUserInfo = mock(Predicate.class);
+		this.userService.setRetrieveUserInfo(customRetrieveUserInfo);
+		given(customRetrieveUserInfo.test(any(OidcUserRequest.class))).willReturn(true);
+		// @formatter:off
+		OAuth2AccessToken accessToken = new OAuth2AccessToken(
+				this.accessToken.getTokenType(),
+				this.accessToken.getTokenValue(),
+				this.accessToken.getIssuedAt(),
+				this.accessToken.getExpiresAt(),
+				Collections.emptySet());
+		// @formatter:on
+		OidcUserRequest userRequest = new OidcUserRequest(this.registration.build(), accessToken, this.idToken);
+		OidcUser oidcUser = this.userService.loadUser(userRequest).block();
+		assertThat(oidcUser).isNotNull();
+		assertThat(oidcUser.getUserInfo()).isNotNull();
+		verify(customRetrieveUserInfo).test(userRequest);
+	}
+
 	@Test
 	public void loadUserWhenTokenContainsScopesThenIndividualScopeAuthorities() {
 		OidcReactiveOAuth2UserService userService = new OidcReactiveOAuth2UserService();

+ 35 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java

@@ -23,6 +23,7 @@ import java.util.Iterator;
 import java.util.Map;
 import java.util.concurrent.TimeUnit;
 import java.util.function.Function;
+import java.util.function.Predicate;
 
 import okhttp3.mockwebserver.MockResponse;
 import okhttp3.mockwebserver.MockWebServer;
@@ -58,6 +59,7 @@ import org.springframework.security.oauth2.core.user.OAuth2UserAuthority;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
+import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.same;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.mock;
@@ -129,6 +131,15 @@ public class OidcUserServiceTests {
 		this.userService.setAccessibleScopes(Collections.emptySet());
 	}
 
+	@Test
+	public void setRetrieveUserInfoWhenNullThenThrowIllegalArgumentException() {
+		// @formatter:off
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> this.userService.setRetrieveUserInfo(null))
+				.withMessage("retrieveUserInfo cannot be null");
+		// @formatter:on
+	}
+
 	@Test
 	public void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() {
 		assertThatIllegalArgumentException().isThrownBy(() -> this.userService.loadUser(null));
@@ -218,6 +229,30 @@ public class OidcUserServiceTests {
 		assertThat(user.getUserInfo()).isNotNull();
 	}
 
+	@Test
+	public void loadUserWhenCustomRetrieveUserInfoSetThenUsed() {
+		// @formatter:off
+		String userInfoResponse = "{\n"
+				+ "   \"sub\": \"subject1\",\n"
+				+ "   \"name\": \"first last\",\n"
+				+ "   \"given_name\": \"first\",\n"
+				+ "   \"family_name\": \"last\",\n"
+				+ "   \"preferred_username\": \"user1\",\n"
+				+ "   \"email\": \"user1@example.com\"\n"
+				+ "}\n";
+		// @formatter:on
+		this.server.enqueue(jsonResponse(userInfoResponse));
+		String userInfoUri = this.server.url("/user").toString();
+		ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build();
+		this.accessToken = TestOAuth2AccessTokens.noScopes();
+		Predicate<OidcUserRequest> customRetrieveUserInfo = mock(Predicate.class);
+		given(customRetrieveUserInfo.test(any(OidcUserRequest.class))).willReturn(true);
+		this.userService.setRetrieveUserInfo(customRetrieveUserInfo);
+		OidcUser user = this.userService
+			.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken));
+		assertThat(user.getUserInfo()).isNotNull();
+	}
+
 	@Test
 	public void loadUserWhenUserInfoSuccessResponseThenReturnUser() {
 		// @formatter:off