Browse Source

Customize mapping the OidcUser

Closes gh-14672
Steve Riesenberg 1 year ago
parent
commit
e52dd81d03

+ 66 - 24
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java

@@ -18,9 +18,8 @@ package org.springframework.security.oauth2.client.oidc.userinfo;
 
 import java.time.Instant;
 import java.util.HashMap;
-import java.util.HashSet;
 import java.util.Map;
-import java.util.Set;
+import java.util.function.BiFunction;
 import java.util.function.Function;
 import java.util.function.Predicate;
 
@@ -28,7 +27,6 @@ import reactor.core.publisher.Mono;
 
 import org.springframework.core.convert.TypeDescriptor;
 import org.springframework.core.convert.converter.Converter;
-import org.springframework.security.core.GrantedAuthority;
 import org.springframework.security.core.authority.SimpleGrantedAuthority;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.userinfo.DefaultReactiveOAuth2UserService;
@@ -40,6 +38,7 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.converter.ClaimConversionService;
 import org.springframework.security.oauth2.core.converter.ClaimTypeConverter;
+import org.springframework.security.oauth2.core.oidc.OidcIdToken;
 import org.springframework.security.oauth2.core.oidc.OidcUserInfo;
 import org.springframework.security.oauth2.core.oidc.StandardClaimNames;
 import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
@@ -47,7 +46,6 @@ import org.springframework.security.oauth2.core.oidc.user.OidcUser;
 import org.springframework.security.oauth2.core.oidc.user.OidcUserAuthority;
 import org.springframework.security.oauth2.core.user.OAuth2User;
 import org.springframework.util.Assert;
-import org.springframework.util.StringUtils;
 
 /**
  * An implementation of an {@link ReactiveOAuth2UserService} that supports OpenID Connect
@@ -75,6 +73,8 @@ public class OidcReactiveOAuth2UserService implements ReactiveOAuth2UserService<
 
 	private Predicate<OidcUserRequest> retrieveUserInfo = OidcUserRequestUtils::shouldRetrieveUserInfo;
 
+	private BiFunction<OidcUserRequest, OidcUserInfo, Mono<OidcUser>> oidcUserMapper = this::getUser;
+
 	/**
 	 * Returns the default {@link Converter}'s used for type conversion of claim values
 	 * for an {@link OidcUserInfo}.
@@ -103,29 +103,15 @@ public class OidcReactiveOAuth2UserService implements ReactiveOAuth2UserService<
 		Assert.notNull(userRequest, "userRequest cannot be null");
 		// @formatter:off
 		return getUserInfo(userRequest)
-				.map((userInfo) ->
-						new OidcUserAuthority(userRequest.getIdToken(), userInfo)
-				)
-				.defaultIfEmpty(new OidcUserAuthority(userRequest.getIdToken(), null))
-				.map((authority) -> {
-					OidcUserInfo userInfo = authority.getUserInfo();
-					Set<GrantedAuthority> authorities = new HashSet<>();
-					authorities.add(authority);
-					OAuth2AccessToken token = userRequest.getAccessToken();
-					for (String scope : token.getScopes()) {
-						authorities.add(new SimpleGrantedAuthority("SCOPE_" + scope));
-					}
-					String userNameAttributeName = userRequest.getClientRegistration().getProviderDetails()
-							.getUserInfoEndpoint().getUserNameAttributeName();
-					if (StringUtils.hasText(userNameAttributeName)) {
-						return new DefaultOidcUser(authorities, userRequest.getIdToken(), userInfo,
-								userNameAttributeName);
-					}
-					return new DefaultOidcUser(authorities, userRequest.getIdToken(), userInfo);
-				});
+				.flatMap((userInfo) -> this.oidcUserMapper.apply(userRequest, userInfo))
+				.switchIfEmpty(Mono.defer(() -> this.oidcUserMapper.apply(userRequest, null)));
 		// @formatter:on
 	}
 
+	private Mono<OidcUser> getUser(OidcUserRequest userRequest, OidcUserInfo userInfo) {
+		return Mono.just(OidcUserRequestUtils.getUser(userRequest, userInfo));
+	}
+
 	private Mono<OidcUserInfo> getUserInfo(OidcUserRequest userRequest) {
 		if (!this.retrieveUserInfo.test(userRequest)) {
 			return Mono.empty();
@@ -193,4 +179,60 @@ public class OidcReactiveOAuth2UserService implements ReactiveOAuth2UserService<
 		this.retrieveUserInfo = retrieveUserInfo;
 	}
 
+	/**
+	 * Sets the {@code BiFunction} used to map the {@link OidcUser user} from the
+	 * {@link OidcUserRequest user request} and {@link OidcUserInfo user info}.
+	 * <p>
+	 * This is useful when you need to map the user or authorities from the access token
+	 * itself. For example, when the authorization server provides authorization
+	 * information in the access token payload you can do the following: <pre>
+	 * 	&#64;Bean
+	 * 	public OidcReactiveOAuth2UserService oidcUserService() {
+	 * 		var userService = new OidcReactiveOAuth2UserService();
+	 * 		userService.setOidcUserMapper(oidcUserMapper());
+	 * 		return userService;
+	 * 	}
+	 *
+	 * 	private static BiFunction&lt;OidcUserRequest, OidcUserInfo, Mono&lt;OidcUser&gt;&gt; oidcUserMapper() {
+	 * 		return (userRequest, userInfo) -> {
+	 * 			var accessToken = userRequest.getAccessToken();
+	 * 			var grantedAuthorities = new HashSet&lt;GrantedAuthority&gt;();
+	 * 			// TODO: Map authorities from the access token
+	 * 			var userNameAttributeName = "preferred_username";
+	 * 			return Mono.just(new DefaultOidcUser(
+	 * 				grantedAuthorities,
+	 * 				userRequest.getIdToken(),
+	 * 				userInfo,
+	 * 				userNameAttributeName
+	 * 			));
+	 * 		};
+	 * 	}
+	 * </pre>
+	 * <p>
+	 * Note that you can access the {@code userNameAttributeName} via the
+	 * {@link ClientRegistration} as follows: <pre>
+	 * 	var userNameAttributeName = userRequest.getClientRegistration()
+	 * 		.getProviderDetails()
+	 * 		.getUserInfoEndpoint()
+	 * 		.getUserNameAttributeName();
+	 * </pre>
+	 * <p>
+	 * By default, a {@link DefaultOidcUser} is created with authorities mapped as
+	 * follows:
+	 * <ul>
+	 * <li>An {@link OidcUserAuthority} is created from the {@link OidcIdToken} and
+	 * {@link OidcUserInfo} with an authority of {@code OIDC_USER}</li>
+	 * <li>Additional {@link SimpleGrantedAuthority authorities} are mapped from the
+	 * {@link OAuth2AccessToken#getScopes() access token scopes} with a prefix of
+	 * {@code SCOPE_}</li>
+	 * </ul>
+	 * @param oidcUserMapper the function used to map the {@link OidcUser} from the
+	 * {@link OidcUserRequest} and {@link OidcUserInfo}
+	 * @since 6.3
+	 */
+	public final void setOidcUserMapper(BiFunction<OidcUserRequest, OidcUserInfo, Mono<OidcUser>> oidcUserMapper) {
+		Assert.notNull(oidcUserMapper, "oidcUserMapper cannot be null");
+		this.oidcUserMapper = oidcUserMapper;
+	}
+
 }

+ 25 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserRequestUtils.java

@@ -16,8 +16,18 @@
 
 package org.springframework.security.oauth2.client.oidc.userinfo;
 
+import java.util.LinkedHashSet;
+import java.util.Set;
+
+import org.springframework.security.core.GrantedAuthority;
+import org.springframework.security.core.authority.SimpleGrantedAuthority;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.oidc.OidcUserInfo;
+import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
+import org.springframework.security.oauth2.core.oidc.user.OidcUser;
+import org.springframework.security.oauth2.core.oidc.user.OidcUserAuthority;
 import org.springframework.util.CollectionUtils;
 import org.springframework.util.StringUtils;
 
@@ -66,6 +76,21 @@ final class OidcUserRequestUtils {
 		return false;
 	}
 
+	static OidcUser getUser(OidcUserRequest userRequest, OidcUserInfo userInfo) {
+		Set<GrantedAuthority> authorities = new LinkedHashSet<>();
+		authorities.add(new OidcUserAuthority(userRequest.getIdToken(), userInfo));
+		OAuth2AccessToken token = userRequest.getAccessToken();
+		for (String scope : token.getScopes()) {
+			authorities.add(new SimpleGrantedAuthority("SCOPE_" + scope));
+		}
+		ClientRegistration.ProviderDetails providerDetails = userRequest.getClientRegistration().getProviderDetails();
+		String userNameAttributeName = providerDetails.getUserInfoEndpoint().getUserNameAttributeName();
+		if (StringUtils.hasText(userNameAttributeName)) {
+			return new DefaultOidcUser(authorities, userRequest.getIdToken(), userInfo, userNameAttributeName);
+		}
+		return new DefaultOidcUser(authorities, userRequest.getIdToken(), userInfo);
+	}
+
 	private OidcUserRequestUtils() {
 	}
 

+ 62 - 18
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java

@@ -20,15 +20,14 @@ import java.time.Instant;
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.HashSet;
-import java.util.LinkedHashSet;
 import java.util.Map;
 import java.util.Set;
+import java.util.function.BiFunction;
 import java.util.function.Function;
 import java.util.function.Predicate;
 
 import org.springframework.core.convert.TypeDescriptor;
 import org.springframework.core.convert.converter.Converter;
-import org.springframework.security.core.GrantedAuthority;
 import org.springframework.security.core.authority.SimpleGrantedAuthority;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.ClientRegistration.ProviderDetails;
@@ -41,6 +40,7 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.converter.ClaimConversionService;
 import org.springframework.security.oauth2.core.converter.ClaimTypeConverter;
+import org.springframework.security.oauth2.core.oidc.OidcIdToken;
 import org.springframework.security.oauth2.core.oidc.OidcScopes;
 import org.springframework.security.oauth2.core.oidc.OidcUserInfo;
 import org.springframework.security.oauth2.core.oidc.StandardClaimNames;
@@ -57,6 +57,7 @@ import org.springframework.util.StringUtils;
  * Provider's.
  *
  * @author Joe Grandja
+ * @author Steve Riesenberg
  * @since 5.0
  * @see OAuth2UserService
  * @see OidcUserRequest
@@ -81,6 +82,8 @@ public class OidcUserService implements OAuth2UserService<OidcUserRequest, OidcU
 
 	private Predicate<OidcUserRequest> retrieveUserInfo = this::shouldRetrieveUserInfo;
 
+	private BiFunction<OidcUserRequest, OidcUserInfo, OidcUser> oidcUserMapper = OidcUserRequestUtils::getUser;
+
 	/**
 	 * Returns the default {@link Converter}'s used for type conversion of claim values
 	 * for an {@link OidcUserInfo}.
@@ -130,13 +133,7 @@ public class OidcUserService implements OAuth2UserService<OidcUserRequest, OidcU
 				throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
 			}
 		}
-		Set<GrantedAuthority> authorities = new LinkedHashSet<>();
-		authorities.add(new OidcUserAuthority(userRequest.getIdToken(), userInfo));
-		OAuth2AccessToken token = userRequest.getAccessToken();
-		for (String authority : token.getScopes()) {
-			authorities.add(new SimpleGrantedAuthority("SCOPE_" + authority));
-		}
-		return getUser(userRequest, userInfo, authorities);
+		return this.oidcUserMapper.apply(userRequest, userInfo);
 	}
 
 	private Map<String, Object> getClaims(OidcUserRequest userRequest, OAuth2User oauth2User) {
@@ -148,15 +145,6 @@ public class OidcUserService implements OAuth2UserService<OidcUserRequest, OidcU
 		return DEFAULT_CLAIM_TYPE_CONVERTER.convert(oauth2User.getAttributes());
 	}
 
-	private OidcUser getUser(OidcUserRequest userRequest, OidcUserInfo userInfo, Set<GrantedAuthority> authorities) {
-		ProviderDetails providerDetails = userRequest.getClientRegistration().getProviderDetails();
-		String userNameAttributeName = providerDetails.getUserInfoEndpoint().getUserNameAttributeName();
-		if (StringUtils.hasText(userNameAttributeName)) {
-			return new DefaultOidcUser(authorities, userRequest.getIdToken(), userInfo, userNameAttributeName);
-		}
-		return new DefaultOidcUser(authorities, userRequest.getIdToken(), userInfo);
-	}
-
 	private boolean shouldRetrieveUserInfo(OidcUserRequest userRequest) {
 		// Auto-disabled if UserInfo Endpoint URI is not provided
 		ProviderDetails providerDetails = userRequest.getClientRegistration().getProviderDetails();
@@ -255,4 +243,60 @@ public class OidcUserService implements OAuth2UserService<OidcUserRequest, OidcU
 		this.retrieveUserInfo = retrieveUserInfo;
 	}
 
+	/**
+	 * Sets the {@code BiFunction} used to map the {@link OidcUser user} from the
+	 * {@link OidcUserRequest user request} and {@link OidcUserInfo user info}.
+	 * <p>
+	 * This is useful when you need to map the user or authorities from the access token
+	 * itself. For example, when the authorization server provides authorization
+	 * information in the access token payload you can do the following: <pre>
+	 * 	&#64;Bean
+	 * 	public OidcUserService oidcUserService() {
+	 * 		var userService = new OidcUserService();
+	 * 		userService.setOidcUserMapper(oidcUserMapper());
+	 * 		return userService;
+	 * 	}
+	 *
+	 * 	private static BiFunction&lt;OidcUserRequest, OidcUserInfo, OidcUser&gt; oidcUserMapper() {
+	 * 		return (userRequest, userInfo) -> {
+	 * 			var accessToken = userRequest.getAccessToken();
+	 * 			var grantedAuthorities = new HashSet&lt;GrantedAuthority&gt;();
+	 * 			// TODO: Map authorities from the access token
+	 * 			var userNameAttributeName = "preferred_username";
+	 * 			return new DefaultOidcUser(
+	 * 				grantedAuthorities,
+	 * 				userRequest.getIdToken(),
+	 * 				userInfo,
+	 * 				userNameAttributeName
+	 * 			);
+	 * 		};
+	 * 	}
+	 * </pre>
+	 * <p>
+	 * Note that you can access the {@code userNameAttributeName} via the
+	 * {@link ClientRegistration} as follows: <pre>
+	 * 	var userNameAttributeName = userRequest.getClientRegistration()
+	 * 		.getProviderDetails()
+	 * 		.getUserInfoEndpoint()
+	 * 		.getUserNameAttributeName();
+	 * </pre>
+	 * <p>
+	 * By default, a {@link DefaultOidcUser} is created with authorities mapped as
+	 * follows:
+	 * <ul>
+	 * <li>An {@link OidcUserAuthority} is created from the {@link OidcIdToken} and
+	 * {@link OidcUserInfo} with an authority of {@code OIDC_USER}</li>
+	 * <li>Additional {@link SimpleGrantedAuthority authorities} are mapped from the
+	 * {@link OAuth2AccessToken#getScopes() access token scopes} with a prefix of
+	 * {@code SCOPE_}</li>
+	 * </ul>
+	 * @param oidcUserMapper the function used to map the {@link OidcUser} from the
+	 * {@link OidcUserRequest} and {@link OidcUserInfo}
+	 * @since 6.3
+	 */
+	public final void setOidcUserMapper(BiFunction<OidcUserRequest, OidcUserInfo, OidcUser> oidcUserMapper) {
+		Assert.notNull(oidcUserMapper, "oidcUserMapper cannot be null");
+		this.oidcUserMapper = oidcUserMapper;
+	}
+
 }

+ 53 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java

@@ -23,6 +23,7 @@ import java.util.Collections;
 import java.util.HashMap;
 import java.util.Iterator;
 import java.util.Map;
+import java.util.function.BiFunction;
 import java.util.function.Function;
 import java.util.function.Predicate;
 
@@ -31,6 +32,7 @@ import okhttp3.mockwebserver.MockWebServer;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.ArgumentCaptor;
 import org.mockito.Mock;
 import org.mockito.junit.jupiter.MockitoExtension;
 import reactor.core.publisher.Mono;
@@ -53,8 +55,10 @@ import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
 import org.springframework.security.oauth2.core.converter.ClaimTypeConverter;
 import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames;
 import org.springframework.security.oauth2.core.oidc.OidcIdToken;
+import org.springframework.security.oauth2.core.oidc.OidcUserInfo;
 import org.springframework.security.oauth2.core.oidc.StandardClaimNames;
 import org.springframework.security.oauth2.core.oidc.TestOidcIdTokens;
+import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
 import org.springframework.security.oauth2.core.oidc.user.OidcUser;
 import org.springframework.security.oauth2.core.user.DefaultOAuth2User;
 import org.springframework.security.oauth2.core.user.OAuth2User;
@@ -64,6 +68,8 @@ import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.ArgumentMatchers.isNull;
 import static org.mockito.ArgumentMatchers.same;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.mock;
@@ -235,6 +241,53 @@ public class OidcReactiveOAuth2UserServiceTests {
 		verify(customRetrieveUserInfo).test(userRequest);
 	}
 
+	@Test
+	public void loadUserWhenCustomOidcUserMapperSetThenUsed() {
+		Map<String, Object> attributes = new HashMap<>();
+		attributes.put(StandardClaimNames.SUB, "subject");
+		attributes.put("user", "steve");
+		OAuth2User oauth2User = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), attributes,
+				"user");
+		given(this.oauth2UserService.loadUser(any(OidcUserRequest.class))).willReturn(Mono.just(oauth2User));
+		BiFunction<OidcUserRequest, OidcUserInfo, Mono<OidcUser>> customOidcUserMapper = mock(BiFunction.class);
+		OidcUser actualUser = new DefaultOidcUser(AuthorityUtils.createAuthorityList("a", "b"), this.idToken,
+				IdTokenClaimNames.SUB);
+		given(customOidcUserMapper.apply(any(OidcUserRequest.class), any(OidcUserInfo.class)))
+			.willReturn(Mono.just(actualUser));
+		this.userService.setOidcUserMapper(customOidcUserMapper);
+		OidcUserRequest userRequest = userRequest();
+		OidcUser oidcUser = this.userService.loadUser(userRequest).block();
+		assertThat(oidcUser).isNotNull();
+		assertThat(oidcUser).isEqualTo(actualUser);
+		ArgumentCaptor<OidcUserInfo> userInfoCaptor = ArgumentCaptor.forClass(OidcUserInfo.class);
+		verify(customOidcUserMapper).apply(eq(userRequest), userInfoCaptor.capture());
+		OidcUserInfo userInfo = userInfoCaptor.getValue();
+		assertThat(userInfo.getSubject()).isEqualTo("subject");
+		assertThat(userInfo.getClaimAsString("user")).isEqualTo("steve");
+	}
+
+	@Test
+	public void loadUserWhenCustomOidcUserMapperSetAndUserInfoNotRetrievedThenUsed() {
+		// @formatter:off
+		this.accessToken = new OAuth2AccessToken(
+				this.accessToken.getTokenType(),
+				this.accessToken.getTokenValue(),
+				this.accessToken.getIssuedAt(),
+				this.accessToken.getExpiresAt(),
+				Collections.emptySet());
+		// @formatter:on
+		BiFunction<OidcUserRequest, OidcUserInfo, Mono<OidcUser>> customOidcUserMapper = mock(BiFunction.class);
+		OidcUser actualUser = new DefaultOidcUser(AuthorityUtils.createAuthorityList("a", "b"), this.idToken,
+				IdTokenClaimNames.SUB);
+		given(customOidcUserMapper.apply(any(OidcUserRequest.class), isNull())).willReturn(Mono.just(actualUser));
+		this.userService.setOidcUserMapper(customOidcUserMapper);
+		OidcUserRequest userRequest = userRequest();
+		OidcUser oidcUser = this.userService.loadUser(userRequest).block();
+		assertThat(oidcUser).isNotNull();
+		assertThat(oidcUser).isEqualTo(actualUser);
+		verify(customOidcUserMapper).apply(eq(userRequest), isNull(OidcUserInfo.class));
+	}
+
 	@Test
 	public void loadUserWhenTokenContainsScopesThenIndividualScopeAuthorities() {
 		OidcReactiveOAuth2UserService userService = new OidcReactiveOAuth2UserService();

+ 46 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java

@@ -22,6 +22,7 @@ import java.util.HashMap;
 import java.util.Iterator;
 import java.util.Map;
 import java.util.concurrent.TimeUnit;
+import java.util.function.BiFunction;
 import java.util.function.Function;
 import java.util.function.Predicate;
 
@@ -31,12 +32,14 @@ import okhttp3.mockwebserver.RecordedRequest;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
+import org.mockito.ArgumentCaptor;
 
 import org.springframework.core.convert.converter.Converter;
 import org.springframework.http.HttpHeaders;
 import org.springframework.http.HttpMethod;
 import org.springframework.http.MediaType;
 import org.springframework.security.core.GrantedAuthority;
+import org.springframework.security.core.authority.AuthorityUtils;
 import org.springframework.security.core.authority.SimpleGrantedAuthority;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
@@ -49,8 +52,10 @@ import org.springframework.security.oauth2.core.converter.ClaimTypeConverter;
 import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames;
 import org.springframework.security.oauth2.core.oidc.OidcIdToken;
 import org.springframework.security.oauth2.core.oidc.OidcScopes;
+import org.springframework.security.oauth2.core.oidc.OidcUserInfo;
 import org.springframework.security.oauth2.core.oidc.StandardClaimNames;
 import org.springframework.security.oauth2.core.oidc.TestOidcIdTokens;
+import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
 import org.springframework.security.oauth2.core.oidc.user.OidcUser;
 import org.springframework.security.oauth2.core.oidc.user.OidcUserAuthority;
 import org.springframework.security.oauth2.core.user.OAuth2User;
@@ -60,6 +65,7 @@ import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.ArgumentMatchers.same;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.mock;
@@ -140,6 +146,15 @@ public class OidcUserServiceTests {
 		// @formatter:on
 	}
 
+	@Test
+	public void setOidcUserMapperWhenNullThenThrowIllegalArgumentException() {
+		// @formatter:off
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> this.userService.setOidcUserMapper(null))
+				.withMessage("oidcUserMapper cannot be null");
+		// @formatter:on
+	}
+
 	@Test
 	public void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() {
 		assertThatIllegalArgumentException().isThrownBy(() -> this.userService.loadUser(null));
@@ -253,6 +268,37 @@ public class OidcUserServiceTests {
 		assertThat(user.getUserInfo()).isNotNull();
 	}
 
+	@Test
+	public void loadUserWhenCustomOidcUserMapperSetThenUsed() {
+		// @formatter:off
+		String userInfoResponse = "{\n"
+				+ "   \"sub\": \"subject1\",\n"
+				+ "   \"name\": \"first last\",\n"
+				+ "   \"given_name\": \"first\",\n"
+				+ "   \"family_name\": \"last\",\n"
+				+ "   \"preferred_username\": \"user1\",\n"
+				+ "   \"email\": \"user1@example.com\"\n"
+				+ "}\n";
+		// @formatter:on
+		this.server.enqueue(jsonResponse(userInfoResponse));
+		String userInfoUri = this.server.url("/user").toString();
+		ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build();
+		this.accessToken = TestOAuth2AccessTokens.noScopes();
+		BiFunction<OidcUserRequest, OidcUserInfo, OidcUser> customOidcUserMapper = mock(BiFunction.class);
+		OidcUser actualUser = new DefaultOidcUser(AuthorityUtils.createAuthorityList("a", "b"), this.idToken,
+				IdTokenClaimNames.SUB);
+		given(customOidcUserMapper.apply(any(OidcUserRequest.class), any(OidcUserInfo.class))).willReturn(actualUser);
+		this.userService.setOidcUserMapper(customOidcUserMapper);
+		OidcUserRequest userRequest = new OidcUserRequest(clientRegistration, this.accessToken, this.idToken);
+		OidcUser user = this.userService.loadUser(userRequest);
+		assertThat(user).isEqualTo(actualUser);
+		ArgumentCaptor<OidcUserInfo> userInfoCaptor = ArgumentCaptor.forClass(OidcUserInfo.class);
+		verify(customOidcUserMapper).apply(eq(userRequest), userInfoCaptor.capture());
+		OidcUserInfo userInfo = userInfoCaptor.getValue();
+		assertThat(userInfo.getSubject()).isEqualTo("subject1");
+		assertThat(userInfo.getClaimAsString("preferred_username")).isEqualTo("user1");
+	}
+
 	@Test
 	public void loadUserWhenUserInfoSuccessResponseThenReturnUser() {
 		// @formatter:off