Quellcode durchsuchen

Add OAuth2User to OidcUser Conversion Params

Previously the Oidc(Reactive)OAuth2UserService APIs allowed a strategy
for converting to the OidcUser with the OidcUserRequest and OidcUserInfo.
The input should also include the OAuth2User to make
it simple to use the OAuth2User as a part of the conversion.

This commit introduces OidcUserSource as a POJO containing
OidcUserRequest, OidcUserInfo, and OAuth2User.

It then updates the OidcUser conversion strategy in OidcUserService and
OidcReactiveOAuth2UserService to accept OidcUserSource as the source for
the Converter used to create OidUser.

Closes gh-17626
Rob Winch vor 1 Monat
Ursprung
Commit
bf877a9864

+ 31 - 27
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java

@@ -73,7 +73,8 @@ public class OidcReactiveOAuth2UserService implements ReactiveOAuth2UserService<
 
 	private Predicate<OidcUserRequest> retrieveUserInfo = OidcUserRequestUtils::shouldRetrieveUserInfo;
 
-	private BiFunction<OidcUserRequest, OidcUserInfo, Mono<OidcUser>> oidcUserMapper = this::getUser;
+	private Converter<OidcUserSource, Mono<OidcUser>> oidcUserConverter = (source) -> Mono
+		.just(OidcUserRequestUtils.getUser(source));
 
 	/**
 	 * Returns the default {@link Converter}'s used for type conversion of claim values
@@ -102,34 +103,26 @@ public class OidcReactiveOAuth2UserService implements ReactiveOAuth2UserService<
 	public Mono<OidcUser> loadUser(OidcUserRequest userRequest) throws OAuth2AuthenticationException {
 		Assert.notNull(userRequest, "userRequest cannot be null");
 		// @formatter:off
-		return getUserInfo(userRequest)
-				.flatMap((userInfo) -> this.oidcUserMapper.apply(userRequest, userInfo))
-				.switchIfEmpty(Mono.defer(() -> this.oidcUserMapper.apply(userRequest, null)));
+		return Mono.just(userRequest)
+			.filter(this.retrieveUserInfo::test)
+			.flatMap(this.oauth2UserService::loadUser)
+			.flatMap((oauth2User) -> toOidcUser(userRequest, oauth2User))
+			.switchIfEmpty(Mono.defer(() -> this.oidcUserConverter.convert(new OidcUserSource(userRequest))));
 		// @formatter:on
 	}
 
-	private Mono<OidcUser> getUser(OidcUserRequest userRequest, OidcUserInfo userInfo) {
-		return Mono.just(OidcUserRequestUtils.getUser(userRequest, userInfo));
-	}
-
-	private Mono<OidcUserInfo> getUserInfo(OidcUserRequest userRequest) {
-		if (!this.retrieveUserInfo.test(userRequest)) {
-			return Mono.empty();
-		}
-		// @formatter:off
-		return this.oauth2UserService
-				.loadUser(userRequest)
-				.map(OAuth2User::getAttributes)
-				.map((claims) -> convertClaims(claims, userRequest.getClientRegistration()))
-				.map(OidcUserInfo::new)
-				.doOnNext((userInfo) -> {
-					String subject = userInfo.getSubject();
-					if (subject == null || !subject.equals(userRequest.getIdToken().getSubject())) {
-						OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE);
-						throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
-					}
-				});
-		// @formatter:on
+	private Mono<OidcUser> toOidcUser(OidcUserRequest userRequest, OAuth2User oauth2User) {
+		return Mono.defer(() -> {
+			Map<String, Object> claims = convertClaims(oauth2User.getAttributes(), userRequest.getClientRegistration());
+			OidcUserInfo userInfo = new OidcUserInfo(claims);
+			String subject = userInfo.getSubject();
+			if (subject == null || !subject.equals(userRequest.getIdToken().getSubject())) {
+				OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE);
+				throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
+			}
+			OidcUserSource source = new OidcUserSource(userRequest, userInfo, oauth2User);
+			return this.oidcUserConverter.convert(source);
+		});
 	}
 
 	private Map<String, Object> convertClaims(Map<String, Object> claims, ClientRegistration clientRegistration) {
@@ -229,10 +222,21 @@ public class OidcReactiveOAuth2UserService implements ReactiveOAuth2UserService<
 	 * @param oidcUserMapper the function used to map the {@link OidcUser} from the
 	 * {@link OidcUserRequest} and {@link OidcUserInfo}
 	 * @since 6.3
+	 * @deprecated Use {@link #setOidcUserConverter(Converter)} instead
 	 */
+	@Deprecated(since = "7.0", forRemoval = true)
 	public final void setOidcUserMapper(BiFunction<OidcUserRequest, OidcUserInfo, Mono<OidcUser>> oidcUserMapper) {
 		Assert.notNull(oidcUserMapper, "oidcUserMapper cannot be null");
-		this.oidcUserMapper = oidcUserMapper;
+		this.oidcUserConverter = (source) -> oidcUserMapper.apply(source.getUserRequest(), source.getUserInfo());
+	}
+
+	/**
+	 * Allows converting from the {@link OidcUserSource} to and {@link OidcUser}.
+	 * @param oidcUserConverter the {@link Converter} to use. Cannot be null.
+	 */
+	public void setOidcUserConverter(Converter<OidcUserSource, Mono<OidcUser>> oidcUserConverter) {
+		Assert.notNull(oidcUserConverter, "oidcUserConverter cannot be null");
+		this.oidcUserConverter = oidcUserConverter;
 	}
 
 }

+ 3 - 1
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequestUtils.java

@@ -76,7 +76,9 @@ final class OidcUserRequestUtils {
 		return false;
 	}
 
-	static OidcUser getUser(OidcUserRequest userRequest, OidcUserInfo userInfo) {
+	static OidcUser getUser(OidcUserSource userMetadata) {
+		OidcUserRequest userRequest = userMetadata.getUserRequest();
+		OidcUserInfo userInfo = userMetadata.getUserInfo();
 		Set<GrantedAuthority> authorities = new LinkedHashSet<>();
 		ClientRegistration.ProviderDetails providerDetails = userRequest.getClientRegistration().getProviderDetails();
 		String userNameAttributeName = providerDetails.getUserInfoEndpoint().getUserNameAttributeName();

+ 17 - 4
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java

@@ -82,7 +82,7 @@ public class OidcUserService implements OAuth2UserService<OidcUserRequest, OidcU
 
 	private Predicate<OidcUserRequest> retrieveUserInfo = this::shouldRetrieveUserInfo;
 
-	private BiFunction<OidcUserRequest, OidcUserInfo, OidcUser> oidcUserMapper = OidcUserRequestUtils::getUser;
+	private Converter<OidcUserSource, OidcUser> oidcUserConverter = OidcUserRequestUtils::getUser;
 
 	/**
 	 * Returns the default {@link Converter}'s used for type conversion of claim values
@@ -111,8 +111,9 @@ public class OidcUserService implements OAuth2UserService<OidcUserRequest, OidcU
 	public OidcUser loadUser(OidcUserRequest userRequest) throws OAuth2AuthenticationException {
 		Assert.notNull(userRequest, "userRequest cannot be null");
 		OidcUserInfo userInfo = null;
+		OAuth2User oauth2User = null;
 		if (this.retrieveUserInfo.test(userRequest)) {
-			OAuth2User oauth2User = this.oauth2UserService.loadUser(userRequest);
+			oauth2User = this.oauth2UserService.loadUser(userRequest);
 			Map<String, Object> claims = getClaims(userRequest, oauth2User);
 			userInfo = new OidcUserInfo(claims);
 			// https://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse
@@ -133,7 +134,8 @@ public class OidcUserService implements OAuth2UserService<OidcUserRequest, OidcU
 				throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
 			}
 		}
-		return this.oidcUserMapper.apply(userRequest, userInfo);
+		OidcUserSource source = new OidcUserSource(userRequest, userInfo, oauth2User);
+		return this.oidcUserConverter.convert(source);
 	}
 
 	private Map<String, Object> getClaims(OidcUserRequest userRequest, OAuth2User oauth2User) {
@@ -293,10 +295,21 @@ public class OidcUserService implements OAuth2UserService<OidcUserRequest, OidcU
 	 * @param oidcUserMapper the function used to map the {@link OidcUser} from the
 	 * {@link OidcUserRequest} and {@link OidcUserInfo}
 	 * @since 6.3
+	 * @deprecated Use {@link #setOidcUserConverter(Converter)} instead
 	 */
+	@Deprecated(since = "7.0", forRemoval = true)
 	public final void setOidcUserMapper(BiFunction<OidcUserRequest, OidcUserInfo, OidcUser> oidcUserMapper) {
 		Assert.notNull(oidcUserMapper, "oidcUserMapper cannot be null");
-		this.oidcUserMapper = oidcUserMapper;
+		this.oidcUserConverter = (source) -> oidcUserMapper.apply(source.getUserRequest(), source.getUserInfo());
+	}
+
+	/**
+	 * Allows converting from the {@link OidcUserSource} to and {@link OidcUser}.
+	 * @param oidcUserConverter the {@link Converter} to use. Cannot be null.
+	 */
+	public void setOidcUserConverter(Converter<OidcUserSource, OidcUser> oidcUserConverter) {
+		Assert.notNull(oidcUserConverter, "oidcUserConverter cannot be null");
+		this.oidcUserConverter = oidcUserConverter;
 	}
 
 }

+ 64 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserSource.java

@@ -0,0 +1,64 @@
+/*
+ * Copyright 2002-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.client.oidc.userinfo;
+
+import org.jspecify.annotations.Nullable;
+
+import org.springframework.security.oauth2.core.oidc.OidcUserInfo;
+import org.springframework.security.oauth2.core.user.OAuth2User;
+import org.springframework.util.Assert;
+
+/**
+ * The source for the converter to
+ * {@link org.springframework.security.oauth2.core.oidc.user.OidcUser}.
+ *
+ * @author Rob Winch
+ * @since 7.0
+ */
+public class OidcUserSource {
+
+	private final OidcUserRequest userRequest;
+
+	private final @Nullable OidcUserInfo userInfo;
+
+	private final @Nullable OAuth2User oauth2User;
+
+	public OidcUserSource(OidcUserRequest userRequest) {
+		this(userRequest, null, null);
+	}
+
+	public OidcUserSource(OidcUserRequest userRequest, @Nullable OidcUserInfo userInfo,
+			@Nullable OAuth2User oauth2User) {
+		Assert.notNull(userRequest, "userRequest cannot be null");
+		this.userRequest = userRequest;
+		this.userInfo = userInfo;
+		this.oauth2User = oauth2User;
+	}
+
+	public OidcUserRequest getUserRequest() {
+		return this.userRequest;
+	}
+
+	public @Nullable OidcUserInfo getUserInfo() {
+		return this.userInfo;
+	}
+
+	public @Nullable OAuth2User getOauth2User() {
+		return this.oauth2User;
+	}
+
+}

+ 30 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java

@@ -316,6 +316,36 @@ public class OidcReactiveOAuth2UserServiceTests {
 		assertThat(userAuthority.getUserNameAttributeName()).isEqualTo("id");
 	}
 
+	@Test
+	public void loadUserWhenCustomOidcUserConverterSetThenUsed() {
+		ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration()
+			.userInfoUri("https://example.com/user")
+			.userInfoAuthenticationMethod(AuthenticationMethod.HEADER)
+			.userNameAttributeName(StandardClaimNames.SUB)
+			.build();
+		this.accessToken = TestOAuth2AccessTokens.scopes(clientRegistration.getScopes().toArray(new String[0]));
+		Converter<OidcUserSource, Mono<OidcUser>> oidcUserConverter = mock(Converter.class);
+		String nameAttributeKey = IdTokenClaimNames.SUB;
+		OidcUser actualUser = new DefaultOidcUser(AuthorityUtils.createAuthorityList("a", "b"), this.idToken,
+				nameAttributeKey);
+		OAuth2User oauth2User = new DefaultOAuth2User(actualUser.getAuthorities(), actualUser.getClaims(),
+				nameAttributeKey);
+		ReactiveOAuth2UserService<OAuth2UserRequest, OAuth2User> oauth2 = mock(ReactiveOAuth2UserService.class);
+		given(oauth2.loadUser(any())).willReturn(Mono.just(oauth2User));
+		given(oidcUserConverter.convert(any())).willReturn(Mono.just(actualUser));
+		this.userService.setOauth2UserService(oauth2);
+		this.userService.setOidcUserConverter(oidcUserConverter);
+		OidcUserRequest userRequest = new OidcUserRequest(clientRegistration, this.accessToken, this.idToken);
+		OidcUser user = this.userService.loadUser(userRequest).block();
+		assertThat(user).isEqualTo(actualUser);
+		ArgumentCaptor<OidcUserSource> metadataCptr = ArgumentCaptor.forClass(OidcUserSource.class);
+		verify(oidcUserConverter).convert(metadataCptr.capture());
+		OidcUserSource metadata = metadataCptr.getValue();
+		assertThat(metadata.getUserRequest()).isEqualTo(userRequest);
+		assertThat(metadata.getOauth2User()).isEqualTo(oauth2User);
+		assertThat(metadata.getUserInfo()).isNotNull();
+	}
+
 	@Test
 	public void loadUserWhenNestedUserInfoSuccessThenReturnUser() throws IOException {
 		// @formatter:off

+ 39 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java

@@ -44,6 +44,8 @@ import org.springframework.security.core.authority.SimpleGrantedAuthority;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
 import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService;
+import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
+import org.springframework.security.oauth2.client.userinfo.OAuth2UserService;
 import org.springframework.security.oauth2.core.AuthenticationMethod;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
@@ -58,6 +60,7 @@ import org.springframework.security.oauth2.core.oidc.TestOidcIdTokens;
 import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
 import org.springframework.security.oauth2.core.oidc.user.OidcUser;
 import org.springframework.security.oauth2.core.oidc.user.OidcUserAuthority;
+import org.springframework.security.oauth2.core.user.DefaultOAuth2User;
 import org.springframework.security.oauth2.core.user.OAuth2User;
 import org.springframework.security.oauth2.core.user.OAuth2UserAuthority;
 
@@ -155,6 +158,15 @@ public class OidcUserServiceTests {
 		// @formatter:on
 	}
 
+	@Test
+	public void setOidcUserConverterWhenNullThenThrowIllegalArgumentException() {
+		// @formatter:off
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> this.userService.setOidcUserConverter(null))
+				.withMessage("oidcUserConverter cannot be null");
+		// @formatter:on
+	}
+
 	@Test
 	public void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() {
 		assertThatIllegalArgumentException().isThrownBy(() -> this.userService.loadUser(null));
@@ -299,6 +311,33 @@ public class OidcUserServiceTests {
 		assertThat(userInfo.getClaimAsString("preferred_username")).isEqualTo("user1");
 	}
 
+	@Test
+	public void loadUserWhenCustomOidcUserConverterSetThenUsed() {
+		ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri("https://example.com/user")
+			.build();
+		this.accessToken = TestOAuth2AccessTokens.noScopes();
+		Converter<OidcUserSource, OidcUser> oidcUserConverter = mock(Converter.class);
+		String nameAttributeKey = IdTokenClaimNames.SUB;
+		OidcUser actualUser = new DefaultOidcUser(AuthorityUtils.createAuthorityList("a", "b"), this.idToken,
+				nameAttributeKey);
+		OAuth2User oauth2User = new DefaultOAuth2User(actualUser.getAuthorities(), actualUser.getClaims(),
+				nameAttributeKey);
+		OAuth2UserService<OAuth2UserRequest, OAuth2User> oauth2 = mock(OAuth2UserService.class);
+		given(oauth2.loadUser(any())).willReturn(oauth2User);
+		given(oidcUserConverter.convert(any())).willReturn(actualUser);
+		this.userService.setOauth2UserService(oauth2);
+		this.userService.setOidcUserConverter(oidcUserConverter);
+		OidcUserRequest userRequest = new OidcUserRequest(clientRegistration, this.accessToken, this.idToken);
+		OidcUser user = this.userService.loadUser(userRequest);
+		assertThat(user).isEqualTo(actualUser);
+		ArgumentCaptor<OidcUserSource> metadataCptr = ArgumentCaptor.forClass(OidcUserSource.class);
+		verify(oidcUserConverter).convert(metadataCptr.capture());
+		OidcUserSource metadata = metadataCptr.getValue();
+		assertThat(metadata.getUserRequest()).isEqualTo(userRequest);
+		assertThat(metadata.getOauth2User()).isEqualTo(oauth2User);
+		assertThat(metadata.getUserInfo()).isNotNull();
+	}
+
 	@Test
 	public void loadUserWhenUserInfoSuccessResponseThenReturnUser() {
 		// @formatter:off