|
@@ -73,7 +73,8 @@ public class OidcReactiveOAuth2UserService implements ReactiveOAuth2UserService<
|
|
|
|
|
|
private Predicate<OidcUserRequest> retrieveUserInfo = OidcUserRequestUtils::shouldRetrieveUserInfo;
|
|
private Predicate<OidcUserRequest> retrieveUserInfo = OidcUserRequestUtils::shouldRetrieveUserInfo;
|
|
|
|
|
|
- private BiFunction<OidcUserRequest, OidcUserInfo, Mono<OidcUser>> oidcUserMapper = this::getUser;
|
|
|
|
|
|
+ private Converter<OidcUserSource, Mono<OidcUser>> oidcUserConverter = (source) -> Mono
|
|
|
|
+ .just(OidcUserRequestUtils.getUser(source));
|
|
|
|
|
|
/**
|
|
/**
|
|
* Returns the default {@link Converter}'s used for type conversion of claim values
|
|
* Returns the default {@link Converter}'s used for type conversion of claim values
|
|
@@ -102,34 +103,26 @@ public class OidcReactiveOAuth2UserService implements ReactiveOAuth2UserService<
|
|
public Mono<OidcUser> loadUser(OidcUserRequest userRequest) throws OAuth2AuthenticationException {
|
|
public Mono<OidcUser> loadUser(OidcUserRequest userRequest) throws OAuth2AuthenticationException {
|
|
Assert.notNull(userRequest, "userRequest cannot be null");
|
|
Assert.notNull(userRequest, "userRequest cannot be null");
|
|
// @formatter:off
|
|
// @formatter:off
|
|
- return getUserInfo(userRequest)
|
|
|
|
- .flatMap((userInfo) -> this.oidcUserMapper.apply(userRequest, userInfo))
|
|
|
|
- .switchIfEmpty(Mono.defer(() -> this.oidcUserMapper.apply(userRequest, null)));
|
|
|
|
|
|
+ return Mono.just(userRequest)
|
|
|
|
+ .filter(this.retrieveUserInfo::test)
|
|
|
|
+ .flatMap(this.oauth2UserService::loadUser)
|
|
|
|
+ .flatMap((oauth2User) -> toOidcUser(userRequest, oauth2User))
|
|
|
|
+ .switchIfEmpty(Mono.defer(() -> this.oidcUserConverter.convert(new OidcUserSource(userRequest))));
|
|
// @formatter:on
|
|
// @formatter:on
|
|
}
|
|
}
|
|
|
|
|
|
- private Mono<OidcUser> getUser(OidcUserRequest userRequest, OidcUserInfo userInfo) {
|
|
|
|
- return Mono.just(OidcUserRequestUtils.getUser(userRequest, userInfo));
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- private Mono<OidcUserInfo> getUserInfo(OidcUserRequest userRequest) {
|
|
|
|
- if (!this.retrieveUserInfo.test(userRequest)) {
|
|
|
|
- return Mono.empty();
|
|
|
|
- }
|
|
|
|
- // @formatter:off
|
|
|
|
- return this.oauth2UserService
|
|
|
|
- .loadUser(userRequest)
|
|
|
|
- .map(OAuth2User::getAttributes)
|
|
|
|
- .map((claims) -> convertClaims(claims, userRequest.getClientRegistration()))
|
|
|
|
- .map(OidcUserInfo::new)
|
|
|
|
- .doOnNext((userInfo) -> {
|
|
|
|
- String subject = userInfo.getSubject();
|
|
|
|
- if (subject == null || !subject.equals(userRequest.getIdToken().getSubject())) {
|
|
|
|
- OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE);
|
|
|
|
- throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
|
|
|
|
- }
|
|
|
|
- });
|
|
|
|
- // @formatter:on
|
|
|
|
|
|
+ private Mono<OidcUser> toOidcUser(OidcUserRequest userRequest, OAuth2User oauth2User) {
|
|
|
|
+ return Mono.defer(() -> {
|
|
|
|
+ Map<String, Object> claims = convertClaims(oauth2User.getAttributes(), userRequest.getClientRegistration());
|
|
|
|
+ OidcUserInfo userInfo = new OidcUserInfo(claims);
|
|
|
|
+ String subject = userInfo.getSubject();
|
|
|
|
+ if (subject == null || !subject.equals(userRequest.getIdToken().getSubject())) {
|
|
|
|
+ OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE);
|
|
|
|
+ throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
|
|
|
|
+ }
|
|
|
|
+ OidcUserSource source = new OidcUserSource(userRequest, userInfo, oauth2User);
|
|
|
|
+ return this.oidcUserConverter.convert(source);
|
|
|
|
+ });
|
|
}
|
|
}
|
|
|
|
|
|
private Map<String, Object> convertClaims(Map<String, Object> claims, ClientRegistration clientRegistration) {
|
|
private Map<String, Object> convertClaims(Map<String, Object> claims, ClientRegistration clientRegistration) {
|
|
@@ -229,10 +222,21 @@ public class OidcReactiveOAuth2UserService implements ReactiveOAuth2UserService<
|
|
* @param oidcUserMapper the function used to map the {@link OidcUser} from the
|
|
* @param oidcUserMapper the function used to map the {@link OidcUser} from the
|
|
* {@link OidcUserRequest} and {@link OidcUserInfo}
|
|
* {@link OidcUserRequest} and {@link OidcUserInfo}
|
|
* @since 6.3
|
|
* @since 6.3
|
|
|
|
+ * @deprecated Use {@link #setOidcUserConverter(Converter)} instead
|
|
*/
|
|
*/
|
|
|
|
+ @Deprecated(since = "7.0", forRemoval = true)
|
|
public final void setOidcUserMapper(BiFunction<OidcUserRequest, OidcUserInfo, Mono<OidcUser>> oidcUserMapper) {
|
|
public final void setOidcUserMapper(BiFunction<OidcUserRequest, OidcUserInfo, Mono<OidcUser>> oidcUserMapper) {
|
|
Assert.notNull(oidcUserMapper, "oidcUserMapper cannot be null");
|
|
Assert.notNull(oidcUserMapper, "oidcUserMapper cannot be null");
|
|
- this.oidcUserMapper = oidcUserMapper;
|
|
|
|
|
|
+ this.oidcUserConverter = (source) -> oidcUserMapper.apply(source.getUserRequest(), source.getUserInfo());
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ /**
|
|
|
|
+ * Allows converting from the {@link OidcUserSource} to and {@link OidcUser}.
|
|
|
|
+ * @param oidcUserConverter the {@link Converter} to use. Cannot be null.
|
|
|
|
+ */
|
|
|
|
+ public void setOidcUserConverter(Converter<OidcUserSource, Mono<OidcUser>> oidcUserConverter) {
|
|
|
|
+ Assert.notNull(oidcUserConverter, "oidcUserConverter cannot be null");
|
|
|
|
+ this.oidcUserConverter = oidcUserConverter;
|
|
}
|
|
}
|
|
|
|
|
|
}
|
|
}
|