瀏覽代碼

Add Authorities from Access Token

Josh Cummings 6 年之前
父節點
當前提交
833bfd0c22

+ 6 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java

@@ -18,10 +18,12 @@ package org.springframework.security.oauth2.client.oidc.userinfo;
 import org.springframework.core.convert.TypeDescriptor;
 import org.springframework.core.convert.TypeDescriptor;
 import org.springframework.core.convert.converter.Converter;
 import org.springframework.core.convert.converter.Converter;
 import org.springframework.security.core.GrantedAuthority;
 import org.springframework.security.core.GrantedAuthority;
+import org.springframework.security.core.authority.SimpleGrantedAuthority;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.userinfo.DefaultReactiveOAuth2UserService;
 import org.springframework.security.oauth2.client.userinfo.DefaultReactiveOAuth2UserService;
 import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
 import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
 import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService;
 import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.converter.ClaimConversionService;
 import org.springframework.security.oauth2.core.converter.ClaimConversionService;
@@ -99,6 +101,10 @@ public class OidcReactiveOAuth2UserService implements
 				OidcUserInfo userInfo = authority.getUserInfo();
 				OidcUserInfo userInfo = authority.getUserInfo();
 				Set<GrantedAuthority> authorities = new HashSet<>();
 				Set<GrantedAuthority> authorities = new HashSet<>();
 				authorities.add(authority);
 				authorities.add(authority);
+				OAuth2AccessToken token = userRequest.getAccessToken();
+				for (String scope : token.getScopes()) {
+					authorities.add(new SimpleGrantedAuthority("SCOPE_" + scope));
+				}
 				String userNameAttributeName = userRequest.getClientRegistration()
 				String userNameAttributeName = userRequest.getClientRegistration()
 							.getProviderDetails().getUserInfoEndpoint().getUserNameAttributeName();
 							.getProviderDetails().getUserInfoEndpoint().getUserNameAttributeName();
 				if (StringUtils.hasText(userNameAttributeName)) {
 				if (StringUtils.hasText(userNameAttributeName)) {

+ 6 - 5
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java

@@ -17,8 +17,6 @@ package org.springframework.security.oauth2.client.oidc.userinfo;
 
 
 import java.time.Instant;
 import java.time.Instant;
 import java.util.Arrays;
 import java.util.Arrays;
-import java.util.Collection;
-import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.HashSet;
 import java.util.LinkedHashSet;
 import java.util.LinkedHashSet;
@@ -29,11 +27,13 @@ import java.util.function.Function;
 import org.springframework.core.convert.TypeDescriptor;
 import org.springframework.core.convert.TypeDescriptor;
 import org.springframework.core.convert.converter.Converter;
 import org.springframework.core.convert.converter.Converter;
 import org.springframework.security.core.GrantedAuthority;
 import org.springframework.security.core.GrantedAuthority;
+import org.springframework.security.core.authority.SimpleGrantedAuthority;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService;
 import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService;
 import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
 import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
 import org.springframework.security.oauth2.client.userinfo.OAuth2UserService;
 import org.springframework.security.oauth2.client.userinfo.OAuth2UserService;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.converter.ClaimConversionService;
 import org.springframework.security.oauth2.core.converter.ClaimConversionService;
@@ -96,7 +96,6 @@ public class OidcUserService implements OAuth2UserService<OidcUserRequest, OidcU
 	public OidcUser loadUser(OidcUserRequest userRequest) throws OAuth2AuthenticationException {
 	public OidcUser loadUser(OidcUserRequest userRequest) throws OAuth2AuthenticationException {
 		Assert.notNull(userRequest, "userRequest cannot be null");
 		Assert.notNull(userRequest, "userRequest cannot be null");
 		OidcUserInfo userInfo = null;
 		OidcUserInfo userInfo = null;
-		Collection<? extends GrantedAuthority> oauth2UserAuthorities = Collections.emptyList();
 		if (this.shouldRetrieveUserInfo(userRequest)) {
 		if (this.shouldRetrieveUserInfo(userRequest)) {
 			OAuth2User oauth2User = this.oauth2UserService.loadUser(userRequest);
 			OAuth2User oauth2User = this.oauth2UserService.loadUser(userRequest);
 
 
@@ -109,7 +108,6 @@ public class OidcUserService implements OAuth2UserService<OidcUserRequest, OidcU
 				claims = DEFAULT_CLAIM_TYPE_CONVERTER.convert(oauth2User.getAttributes());
 				claims = DEFAULT_CLAIM_TYPE_CONVERTER.convert(oauth2User.getAttributes());
 			}
 			}
 			userInfo = new OidcUserInfo(claims);
 			userInfo = new OidcUserInfo(claims);
-			oauth2UserAuthorities = oauth2User.getAuthorities();
 
 
 			// https://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse
 			// https://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse
 
 
@@ -133,7 +131,10 @@ public class OidcUserService implements OAuth2UserService<OidcUserRequest, OidcU
 
 
 		Set<GrantedAuthority> authorities = new LinkedHashSet<>();
 		Set<GrantedAuthority> authorities = new LinkedHashSet<>();
 		authorities.add(new OidcUserAuthority(userRequest.getIdToken(), userInfo));
 		authorities.add(new OidcUserAuthority(userRequest.getIdToken(), userInfo));
-		authorities.addAll(oauth2UserAuthorities);
+		OAuth2AccessToken token = userRequest.getAccessToken();
+		for (String authority : token.getScopes()) {
+			authorities.add(new SimpleGrantedAuthority("SCOPE_" + authority));
+		}
 
 
 		OidcUser user;
 		OidcUser user;
 
 

+ 3 - 38
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserService.java

@@ -15,9 +15,6 @@
  */
  */
 package org.springframework.security.oauth2.client.userinfo;
 package org.springframework.security.oauth2.client.userinfo;
 
 
-import java.util.Arrays;
-import java.util.Collection;
-import java.util.Collections;
 import java.util.LinkedHashSet;
 import java.util.LinkedHashSet;
 import java.util.Map;
 import java.util.Map;
 import java.util.Set;
 import java.util.Set;
@@ -30,7 +27,7 @@ import org.springframework.security.core.GrantedAuthority;
 import org.springframework.security.core.authority.SimpleGrantedAuthority;
 import org.springframework.security.core.authority.SimpleGrantedAuthority;
 import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler;
 import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
-import org.springframework.security.oauth2.core.ClaimAccessor;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
 import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2Error;
@@ -71,9 +68,6 @@ public class DefaultOAuth2UserService implements OAuth2UserService<OAuth2UserReq
 	private static final ParameterizedTypeReference<Map<String, Object>> PARAMETERIZED_RESPONSE_TYPE =
 	private static final ParameterizedTypeReference<Map<String, Object>> PARAMETERIZED_RESPONSE_TYPE =
 			new ParameterizedTypeReference<Map<String, Object>>() {};
 			new ParameterizedTypeReference<Map<String, Object>>() {};
 
 
-	private static final Collection<String> WELL_KNOWN_AUTHORITIES_CLAIM_NAMES =
-			Arrays.asList("scope", "scp");
-
 	private Converter<OAuth2UserRequest, RequestEntity<?>> requestEntityConverter = new OAuth2UserRequestEntityConverter();
 	private Converter<OAuth2UserRequest, RequestEntity<?>> requestEntityConverter = new OAuth2UserRequestEntityConverter();
 
 
 	private RestOperations restOperations;
 	private RestOperations restOperations;
@@ -137,7 +131,8 @@ public class DefaultOAuth2UserService implements OAuth2UserService<OAuth2UserReq
 		Map<String, Object> userAttributes = response.getBody();
 		Map<String, Object> userAttributes = response.getBody();
 		Set<GrantedAuthority> authorities = new LinkedHashSet<>();
 		Set<GrantedAuthority> authorities = new LinkedHashSet<>();
 		authorities.add(new OAuth2UserAuthority(userAttributes));
 		authorities.add(new OAuth2UserAuthority(userAttributes));
-		for (String authority : getAuthorities(() -> userAttributes)) {
+		OAuth2AccessToken token = userRequest.getAccessToken();
+		for (String authority : token.getScopes()) {
 			authorities.add(new SimpleGrantedAuthority("SCOPE_" + authority));
 			authorities.add(new SimpleGrantedAuthority("SCOPE_" + authority));
 		}
 		}
 
 
@@ -172,34 +167,4 @@ public class DefaultOAuth2UserService implements OAuth2UserService<OAuth2UserReq
 		Assert.notNull(restOperations, "restOperations cannot be null");
 		Assert.notNull(restOperations, "restOperations cannot be null");
 		this.restOperations = restOperations;
 		this.restOperations = restOperations;
 	}
 	}
-
-	private String getAuthoritiesClaimName(ClaimAccessor claims) {
-		for (String claimName : WELL_KNOWN_AUTHORITIES_CLAIM_NAMES) {
-			if (claims.containsClaim(claimName)) {
-				return claimName;
-			}
-		}
-		return null;
-	}
-
-	private Collection<String> getAuthorities(ClaimAccessor claims) {
-		String claimName = getAuthoritiesClaimName(claims);
-
-		if (claimName == null) {
-			return Collections.emptyList();
-		}
-
-		Object authorities = claims.getClaim(claimName);
-		if (authorities instanceof String) {
-			if (StringUtils.hasText((String) authorities)) {
-				return Arrays.asList(((String) authorities).split(" "));
-			} else {
-				return Collections.emptyList();
-			}
-		} else if (authorities instanceof Collection) {
-			return (Collection<String>) authorities;
-		}
-
-		return Collections.emptyList();
-	}
 }
 }

+ 6 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserService.java

@@ -28,7 +28,9 @@ import org.springframework.http.HttpStatus;
 import org.springframework.http.MediaType;
 import org.springframework.http.MediaType;
 import org.springframework.security.authentication.AuthenticationServiceException;
 import org.springframework.security.authentication.AuthenticationServiceException;
 import org.springframework.security.core.GrantedAuthority;
 import org.springframework.security.core.GrantedAuthority;
+import org.springframework.security.core.authority.SimpleGrantedAuthority;
 import org.springframework.security.oauth2.core.AuthenticationMethod;
 import org.springframework.security.oauth2.core.AuthenticationMethod;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.user.DefaultOAuth2User;
 import org.springframework.security.oauth2.core.user.DefaultOAuth2User;
@@ -131,6 +133,10 @@ public class DefaultReactiveOAuth2UserService implements ReactiveOAuth2UserServi
 				GrantedAuthority authority = new OAuth2UserAuthority(attrs);
 				GrantedAuthority authority = new OAuth2UserAuthority(attrs);
 				Set<GrantedAuthority> authorities = new HashSet<>();
 				Set<GrantedAuthority> authorities = new HashSet<>();
 				authorities.add(authority);
 				authorities.add(authority);
+				OAuth2AccessToken token = userRequest.getAccessToken();
+				for (String scope : token.getScopes()) {
+					authorities.add(new SimpleGrantedAuthority("SCOPE_" + scope));
+				}
 
 
 				return new DefaultOAuth2User(authorities, attrs, userNameAttributeName);
 				return new DefaultOAuth2User(authorities, attrs, userNameAttributeName);
 			})
 			})

+ 58 - 11
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java

@@ -16,13 +16,25 @@
 
 
 package org.springframework.security.oauth2.client.oidc.userinfo;
 package org.springframework.security.oauth2.client.oidc.userinfo;
 
 
+import java.time.Duration;
+import java.time.Instant;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.function.Function;
+
 import org.junit.Before;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runner.RunWith;
 import org.mockito.Mock;
 import org.mockito.Mock;
 import org.mockito.junit.MockitoJUnitRunner;
 import org.mockito.junit.MockitoJUnitRunner;
+import reactor.core.publisher.Mono;
+
 import org.springframework.core.convert.converter.Converter;
 import org.springframework.core.convert.converter.Converter;
+import org.springframework.security.core.GrantedAuthority;
 import org.springframework.security.core.authority.AuthorityUtils;
 import org.springframework.security.core.authority.AuthorityUtils;
+import org.springframework.security.core.authority.SimpleGrantedAuthority;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
 import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
 import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
@@ -36,17 +48,20 @@ import org.springframework.security.oauth2.core.oidc.StandardClaimNames;
 import org.springframework.security.oauth2.core.oidc.user.OidcUser;
 import org.springframework.security.oauth2.core.oidc.user.OidcUser;
 import org.springframework.security.oauth2.core.user.DefaultOAuth2User;
 import org.springframework.security.oauth2.core.user.DefaultOAuth2User;
 import org.springframework.security.oauth2.core.user.OAuth2User;
 import org.springframework.security.oauth2.core.user.OAuth2User;
-import reactor.core.publisher.Mono;
-
-import java.time.Duration;
-import java.time.Instant;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.function.Function;
-
-import static org.assertj.core.api.Assertions.*;
-import static org.mockito.Mockito.*;
+import org.springframework.security.oauth2.core.user.OAuth2UserAuthority;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatCode;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.same;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration;
+import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes;
+import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.scopes;
+import static org.springframework.security.oauth2.core.oidc.TestOidcIdTokens.idToken;
 
 
 /**
 /**
  * @author Rob Winch
  * @author Rob Winch
@@ -178,6 +193,38 @@ public class OidcReactiveOAuth2UserServiceTests {
 		verify(customClaimTypeConverterFactory).apply(same(userRequest.getClientRegistration()));
 		verify(customClaimTypeConverterFactory).apply(same(userRequest.getClientRegistration()));
 	}
 	}
 
 
+	@Test
+	public void loadUserWhenTokenContainsScopesThenIndividualScopeAuthorities() {
+		Map<String, Object> body = new HashMap<>();
+		body.put("id", "id");
+		body.put("sub", "test-subject");
+		OidcReactiveOAuth2UserService userService = new OidcReactiveOAuth2UserService();
+		OidcUserRequest request = new OidcUserRequest(
+				clientRegistration().build(), scopes("message:read", "message:write"), idToken(body));
+		OidcUser user = userService.loadUser(request).block();
+
+		assertThat(user.getAuthorities()).hasSize(3);
+		Iterator<? extends GrantedAuthority> authorities = user.getAuthorities().iterator();
+		assertThat(authorities.next()).isInstanceOf(OAuth2UserAuthority.class);
+		assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:read"));
+		assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:write"));
+	}
+
+	@Test
+	public void loadUserWhenTokenDoesNotContainScopesThenNoScopeAuthorities() {
+		Map<String, Object> body = new HashMap<>();
+		body.put("id", "id");
+		body.put("sub", "test-subject");
+		OidcReactiveOAuth2UserService userService = new OidcReactiveOAuth2UserService();
+		OidcUserRequest request = new OidcUserRequest(
+				clientRegistration().build(), noScopes(), idToken(body));
+		OidcUser user = userService.loadUser(request).block();
+
+		assertThat(user.getAuthorities()).hasSize(1);
+		Iterator<? extends GrantedAuthority> authorities = user.getAuthorities().iterator();
+		assertThat(authorities.next()).isInstanceOf(OAuth2UserAuthority.class);
+	}
+
 	private OidcUserRequest userRequest() {
 	private OidcUserRequest userRequest() {
 		return new OidcUserRequest(this.registration.build(), this.accessToken, this.idToken);
 		return new OidcUserRequest(this.registration.build(), this.accessToken, this.idToken);
 	}
 	}

+ 8 - 51
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java

@@ -16,7 +16,6 @@
 package org.springframework.security.oauth2.client.oidc.userinfo;
 package org.springframework.security.oauth2.client.oidc.userinfo;
 
 
 import java.time.Instant;
 import java.time.Instant;
-import java.util.Arrays;
 import java.util.Collections;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashMap;
 import java.util.Iterator;
 import java.util.Iterator;
@@ -33,19 +32,14 @@ import org.junit.Rule;
 import org.junit.Test;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
 import org.junit.rules.ExpectedException;
 
 
-import org.springframework.core.ParameterizedTypeReference;
 import org.springframework.core.convert.converter.Converter;
 import org.springframework.core.convert.converter.Converter;
 import org.springframework.http.HttpHeaders;
 import org.springframework.http.HttpHeaders;
 import org.springframework.http.HttpMethod;
 import org.springframework.http.HttpMethod;
-import org.springframework.http.HttpStatus;
 import org.springframework.http.MediaType;
 import org.springframework.http.MediaType;
-import org.springframework.http.RequestEntity;
-import org.springframework.http.ResponseEntity;
 import org.springframework.security.core.GrantedAuthority;
 import org.springframework.security.core.GrantedAuthority;
 import org.springframework.security.core.authority.SimpleGrantedAuthority;
 import org.springframework.security.core.authority.SimpleGrantedAuthority;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService;
 import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService;
-import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
 import org.springframework.security.oauth2.core.AuthenticationMethod;
 import org.springframework.security.oauth2.core.AuthenticationMethod;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
@@ -56,18 +50,16 @@ import org.springframework.security.oauth2.core.oidc.OidcScopes;
 import org.springframework.security.oauth2.core.oidc.StandardClaimNames;
 import org.springframework.security.oauth2.core.oidc.StandardClaimNames;
 import org.springframework.security.oauth2.core.oidc.user.OidcUser;
 import org.springframework.security.oauth2.core.oidc.user.OidcUser;
 import org.springframework.security.oauth2.core.oidc.user.OidcUserAuthority;
 import org.springframework.security.oauth2.core.oidc.user.OidcUserAuthority;
-import org.springframework.web.client.RestOperations;
 
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.hamcrest.CoreMatchers.containsString;
 import static org.hamcrest.CoreMatchers.containsString;
-import static org.mockito.Mockito.any;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.nullable;
 import static org.mockito.Mockito.same;
 import static org.mockito.Mockito.same;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 import static org.mockito.Mockito.when;
 import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration;
 import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration;
+import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes;
 import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.scopes;
 import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.scopes;
 import static org.springframework.security.oauth2.core.oidc.TestOidcIdTokens.idToken;
 import static org.springframework.security.oauth2.core.oidc.TestOidcIdTokens.idToken;
 
 
@@ -272,7 +264,7 @@ public class OidcUserServiceTests {
 		assertThat(user.getUserInfo().getPreferredUsername()).isEqualTo("user1");
 		assertThat(user.getUserInfo().getPreferredUsername()).isEqualTo("user1");
 		assertThat(user.getUserInfo().getEmail()).isEqualTo("user1@example.com");
 		assertThat(user.getUserInfo().getEmail()).isEqualTo("user1@example.com");
 
 
-		assertThat(user.getAuthorities().size()).isEqualTo(1);
+		assertThat(user.getAuthorities().size()).isEqualTo(3);
 		assertThat(user.getAuthorities().iterator().next()).isInstanceOf(OidcUserAuthority.class);
 		assertThat(user.getAuthorities().iterator().next()).isInstanceOf(OidcUserAuthority.class);
 		OidcUserAuthority userAuthority = (OidcUserAuthority) user.getAuthorities().iterator().next();
 		OidcUserAuthority userAuthority = (OidcUserAuthority) user.getAuthorities().iterator().next();
 		assertThat(userAuthority.getAuthority()).isEqualTo("ROLE_USER");
 		assertThat(userAuthority.getAuthority()).isEqualTo("ROLE_USER");
@@ -499,15 +491,13 @@ public class OidcUserServiceTests {
 	}
 	}
 
 
 	@Test
 	@Test
-	public void loadUserWhenAttributesContainScopeThenIndividualScopeAuthorities() {
+	public void loadUserWhenTokenContainsScopesThenIndividualScopeAuthorities() {
 		Map<String, Object> body = new HashMap<>();
 		Map<String, Object> body = new HashMap<>();
 		body.put("id", "id");
 		body.put("id", "id");
 		body.put("sub", "test-subject");
 		body.put("sub", "test-subject");
-		body.put("scope", "message:read message:write");
 		OidcUserService userService = new OidcUserService();
 		OidcUserService userService = new OidcUserService();
-		userService.setOauth2UserService(withMockResponse(body));
-		OidcUserRequest request = new OidcUserRequest(clientRegistration().
-				userInfoUri("uri").build(), scopes("profile"), idToken(body));
+		OidcUserRequest request = new OidcUserRequest(clientRegistration().build(),
+				scopes("message:read", "message:write"), idToken(body));
 		OidcUser user = userService.loadUser(request);
 		OidcUser user = userService.loadUser(request);
 
 
 		assertThat(user.getAuthorities()).hasSize(3);
 		assertThat(user.getAuthorities()).hasSize(3);
@@ -518,34 +508,13 @@ public class OidcUserServiceTests {
 	}
 	}
 
 
 	@Test
 	@Test
-	public void loadUserWhenAttributesContainScpThenIndividualScopeAuthorities() {
+	public void loadUserWhenTokenDoesNotContainScopesThenNoScopeAuthorities() {
 		Map<String, Object> body = new HashMap<>();
 		Map<String, Object> body = new HashMap<>();
 		body.put("id", "id");
 		body.put("id", "id");
 		body.put("sub", "test-subject");
 		body.put("sub", "test-subject");
-		body.put("scp", Arrays.asList("message:read", "message:write"));
 		OidcUserService userService = new OidcUserService();
 		OidcUserService userService = new OidcUserService();
-		userService.setOauth2UserService(withMockResponse(body));
-		OidcUserRequest request = new OidcUserRequest(clientRegistration().
-				userInfoUri("uri").build(), scopes("profile"), idToken(body));
-		OidcUser user = userService.loadUser(request);
-
-		assertThat(user.getAuthorities()).hasSize(3);
-		Iterator<? extends GrantedAuthority> authorities = user.getAuthorities().iterator();
-		assertThat(authorities.next()).isInstanceOf(OidcUserAuthority.class);
-		assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:read"));
-		assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:write"));
-	}
-
-	@Test
-	public void loadUserWhenAttributesDoesNotContainScopesThenNoScopeAuthorities() {
-		Map<String, Object> body = new HashMap<>();
-		body.put("id", "id");
-		body.put("sub", "test-subject");
-		body.put("authorities", Arrays.asList("message:read", "message:write"));
-		OidcUserService userService = new OidcUserService();
-		userService.setOauth2UserService(withMockResponse(body));
-		OidcUserRequest request = new OidcUserRequest(clientRegistration().
-				userInfoUri("uri").build(), scopes("profile"), idToken(body));
+		OidcUserRequest request = new OidcUserRequest(clientRegistration().build(),
+				noScopes(), idToken(body));
 		OidcUser user = userService.loadUser(request);
 		OidcUser user = userService.loadUser(request);
 
 
 		assertThat(user.getAuthorities()).hasSize(1);
 		assertThat(user.getAuthorities()).hasSize(1);
@@ -553,18 +522,6 @@ public class OidcUserServiceTests {
 		assertThat(authorities.next()).isInstanceOf(OidcUserAuthority.class);
 		assertThat(authorities.next()).isInstanceOf(OidcUserAuthority.class);
 	}
 	}
 
 
-	private DefaultOAuth2UserService withMockResponse(Map<String, Object> response) {
-		ResponseEntity<Map<String, Object>> responseEntity = new ResponseEntity<>(response, HttpStatus.OK);
-		Converter<OAuth2UserRequest, RequestEntity<?>> requestEntityConverter = mock(Converter.class);
-		RestOperations rest = mock(RestOperations.class);
-		when(rest.exchange(nullable(RequestEntity.class), any(ParameterizedTypeReference.class)))
-				.thenReturn(responseEntity);
-		DefaultOAuth2UserService userService = new DefaultOAuth2UserService();
-		userService.setRequestEntityConverter(requestEntityConverter);
-		userService.setRestOperations(rest);
-		return userService;
-	}
-
 	private MockResponse jsonResponse(String json) {
 	private MockResponse jsonResponse(String json) {
 		return new MockResponse()
 		return new MockResponse()
 				.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
 				.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)

+ 7 - 23
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java

@@ -15,7 +15,6 @@
  */
  */
 package org.springframework.security.oauth2.client.userinfo;
 package org.springframework.security.oauth2.client.userinfo;
 
 
-import java.util.Arrays;
 import java.util.HashMap;
 import java.util.HashMap;
 import java.util.Iterator;
 import java.util.Iterator;
 import java.util.Map;
 import java.util.Map;
@@ -56,6 +55,7 @@ import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 import static org.mockito.Mockito.when;
 import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration;
 import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration;
 import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes;
 import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes;
+import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.scopes;
 
 
 /**
 /**
  * Tests for {@link DefaultOAuth2UserService}.
  * Tests for {@link DefaultOAuth2UserService}.
@@ -342,12 +342,12 @@ public class DefaultOAuth2UserServiceTests {
 	}
 	}
 
 
 	@Test
 	@Test
-	public void loadUserWhenAttributesContainScopeThenIndividualScopeAuthorities() {
+	public void loadUserWhenTokenContainsScopesThenIndividualScopeAuthorities() {
 		Map<String, Object> body = new HashMap<>();
 		Map<String, Object> body = new HashMap<>();
 		body.put("id", "id");
 		body.put("id", "id");
-		body.put("scope", "message:read message:write");
 		DefaultOAuth2UserService userService = withMockResponse(body);
 		DefaultOAuth2UserService userService = withMockResponse(body);
-		OAuth2UserRequest request = new OAuth2UserRequest(clientRegistration().build(), noScopes());
+		OAuth2UserRequest request = new OAuth2UserRequest(
+				clientRegistration().build(), scopes("message:read", "message:write"));
 		OAuth2User user = userService.loadUser(request);
 		OAuth2User user = userService.loadUser(request);
 
 
 		assertThat(user.getAuthorities()).hasSize(3);
 		assertThat(user.getAuthorities()).hasSize(3);
@@ -358,28 +358,12 @@ public class DefaultOAuth2UserServiceTests {
 	}
 	}
 
 
 	@Test
 	@Test
-	public void loadUserWhenAttributesContainScpThenIndividualScopeAuthorities() {
+	public void loadUserWhenTokenDoesNotContainScopesThenNoScopeAuthorities() {
 		Map<String, Object> body = new HashMap<>();
 		Map<String, Object> body = new HashMap<>();
 		body.put("id", "id");
 		body.put("id", "id");
-		body.put("scp", Arrays.asList("message:read", "message:write"));
 		DefaultOAuth2UserService userService = withMockResponse(body);
 		DefaultOAuth2UserService userService = withMockResponse(body);
-		OAuth2UserRequest request = new OAuth2UserRequest(clientRegistration().build(), noScopes());
-		OAuth2User user = userService.loadUser(request);
-
-		assertThat(user.getAuthorities()).hasSize(3);
-		Iterator<? extends GrantedAuthority> authorities = user.getAuthorities().iterator();
-		assertThat(authorities.next()).isInstanceOf(OAuth2UserAuthority.class);
-		assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:read"));
-		assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:write"));
-	}
-
-	@Test
-	public void loadUserWhenAttributesDoesNotContainScopesThenNoScopeAuthorities() {
-		Map<String, Object> body = new HashMap<>();
-		body.put("id", "id");
-		body.put("authorities", Arrays.asList("message:read", "message:write"));
-		DefaultOAuth2UserService userService = withMockResponse(body);
-		OAuth2UserRequest request = new OAuth2UserRequest(clientRegistration().build(), noScopes());
+		OAuth2UserRequest request = new OAuth2UserRequest(
+				clientRegistration().build(), noScopes());
 		OAuth2User user = userService.loadUser(request);
 		OAuth2User user = userService.loadUser(request);
 
 
 		assertThat(user.getAuthorities()).hasSize(1);
 		assertThat(user.getAuthorities()).hasSize(1);

+ 73 - 8
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserServiceTests.java

@@ -16,15 +16,30 @@
 
 
 package org.springframework.security.oauth2.client.userinfo;
 package org.springframework.security.oauth2.client.userinfo;
 
 
+import java.time.Duration;
+import java.time.Instant;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.function.Function;
+import java.util.function.Predicate;
+
 import okhttp3.mockwebserver.MockResponse;
 import okhttp3.mockwebserver.MockResponse;
 import okhttp3.mockwebserver.MockWebServer;
 import okhttp3.mockwebserver.MockWebServer;
+import okhttp3.mockwebserver.RecordedRequest;
 import org.junit.After;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.Test;
+import reactor.core.publisher.Mono;
+import reactor.test.StepVerifier;
+
+import org.springframework.core.ParameterizedTypeReference;
 import org.springframework.http.HttpHeaders;
 import org.springframework.http.HttpHeaders;
 import org.springframework.http.HttpMethod;
 import org.springframework.http.HttpMethod;
 import org.springframework.http.MediaType;
 import org.springframework.http.MediaType;
 import org.springframework.security.authentication.AuthenticationServiceException;
 import org.springframework.security.authentication.AuthenticationServiceException;
+import org.springframework.security.core.GrantedAuthority;
+import org.springframework.security.core.authority.SimpleGrantedAuthority;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
 import org.springframework.security.oauth2.core.AuthenticationMethod;
 import org.springframework.security.oauth2.core.AuthenticationMethod;
@@ -32,14 +47,17 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.user.OAuth2User;
 import org.springframework.security.oauth2.core.user.OAuth2User;
 import org.springframework.security.oauth2.core.user.OAuth2UserAuthority;
 import org.springframework.security.oauth2.core.user.OAuth2UserAuthority;
-
-import okhttp3.mockwebserver.RecordedRequest;
-import reactor.test.StepVerifier;
-
-import java.time.Duration;
-import java.time.Instant;
-
-import static org.assertj.core.api.Assertions.*;
+import org.springframework.web.reactive.function.client.WebClient;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.when;
+import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration;
+import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes;
+import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.scopes;
 
 
 /**
 /**
  * @author Rob Winch
  * @author Rob Winch
@@ -211,6 +229,53 @@ public class DefaultReactiveOAuth2UserServiceTests {
 				.isInstanceOf(AuthenticationServiceException.class);
 				.isInstanceOf(AuthenticationServiceException.class);
 	}
 	}
 
 
+	@Test
+	public void loadUserWhenTokenContainsScopesThenIndividualScopeAuthorities() {
+		Map<String, Object> body = new HashMap<>();
+		body.put("id", "id");
+		DefaultReactiveOAuth2UserService userService = withMockResponse(body);
+		OAuth2UserRequest request = new OAuth2UserRequest(
+				clientRegistration().build(), scopes("message:read", "message:write"));
+		OAuth2User user = userService.loadUser(request).block();
+
+		assertThat(user.getAuthorities()).hasSize(3);
+		Iterator<? extends GrantedAuthority> authorities = user.getAuthorities().iterator();
+		assertThat(authorities.next()).isInstanceOf(OAuth2UserAuthority.class);
+		assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:read"));
+		assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:write"));
+	}
+
+	@Test
+	public void loadUserWhenTokenDoesNotContainScopesThenNoScopeAuthorities() {
+		Map<String, Object> body = new HashMap<>();
+		body.put("id", "id");
+		DefaultReactiveOAuth2UserService userService = withMockResponse(body);
+		OAuth2UserRequest request = new OAuth2UserRequest(
+				clientRegistration().build(), noScopes());
+		OAuth2User user = userService.loadUser(request).block();
+
+		assertThat(user.getAuthorities()).hasSize(1);
+		Iterator<? extends GrantedAuthority> authorities = user.getAuthorities().iterator();
+		assertThat(authorities.next()).isInstanceOf(OAuth2UserAuthority.class);
+	}
+
+	private DefaultReactiveOAuth2UserService withMockResponse(Map<String, Object> body) {
+		WebClient real = WebClient.builder().build();
+		WebClient.RequestHeadersUriSpec spec = spy(real.post());
+		WebClient rest = spy(WebClient.class);
+		WebClient.ResponseSpec clientResponse = mock(WebClient.ResponseSpec.class);
+		when(rest.get()).thenReturn(spec);
+		when(spec.retrieve()).thenReturn(clientResponse);
+		when(clientResponse.onStatus(any(Predicate.class), any(Function.class)))
+				.thenReturn(clientResponse);
+		when(clientResponse.bodyToMono(any(ParameterizedTypeReference.class)))
+				.thenReturn(Mono.just(body));
+
+		DefaultReactiveOAuth2UserService userService = new DefaultReactiveOAuth2UserService();
+		userService.setWebClient(rest);
+		return userService;
+	}
+
 	private OAuth2UserRequest oauth2UserRequest() {
 	private OAuth2UserRequest oauth2UserRequest() {
 		return new OAuth2UserRequest(this.clientRegistration.build(), this.accessToken);
 		return new OAuth2UserRequest(this.clientRegistration.build(), this.accessToken);
 	}
 	}