|
@@ -15,6 +15,14 @@
|
|
*/
|
|
*/
|
|
package org.springframework.security.oauth2.client.oidc.userinfo;
|
|
package org.springframework.security.oauth2.client.oidc.userinfo;
|
|
|
|
|
|
|
|
+import java.time.Instant;
|
|
|
|
+import java.util.Arrays;
|
|
|
|
+import java.util.HashMap;
|
|
|
|
+import java.util.LinkedHashSet;
|
|
|
|
+import java.util.Map;
|
|
|
|
+import java.util.Set;
|
|
|
|
+import java.util.concurrent.TimeUnit;
|
|
|
|
+
|
|
import okhttp3.mockwebserver.MockResponse;
|
|
import okhttp3.mockwebserver.MockResponse;
|
|
import okhttp3.mockwebserver.MockWebServer;
|
|
import okhttp3.mockwebserver.MockWebServer;
|
|
import okhttp3.mockwebserver.RecordedRequest;
|
|
import okhttp3.mockwebserver.RecordedRequest;
|
|
@@ -23,17 +31,13 @@ import org.junit.Before;
|
|
import org.junit.Rule;
|
|
import org.junit.Rule;
|
|
import org.junit.Test;
|
|
import org.junit.Test;
|
|
import org.junit.rules.ExpectedException;
|
|
import org.junit.rules.ExpectedException;
|
|
-import org.junit.runner.RunWith;
|
|
|
|
-import org.powermock.core.classloader.annotations.PowerMockIgnore;
|
|
|
|
-import org.powermock.core.classloader.annotations.PrepareForTest;
|
|
|
|
-import org.powermock.modules.junit4.PowerMockRunner;
|
|
|
|
|
|
+
|
|
import org.springframework.http.HttpHeaders;
|
|
import org.springframework.http.HttpHeaders;
|
|
import org.springframework.http.HttpMethod;
|
|
import org.springframework.http.HttpMethod;
|
|
import org.springframework.http.MediaType;
|
|
import org.springframework.http.MediaType;
|
|
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
|
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
|
import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService;
|
|
import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService;
|
|
import org.springframework.security.oauth2.core.AuthenticationMethod;
|
|
import org.springframework.security.oauth2.core.AuthenticationMethod;
|
|
-import org.springframework.security.oauth2.core.AuthorizationGrantType;
|
|
|
|
import org.springframework.security.oauth2.core.OAuth2AccessToken;
|
|
import org.springframework.security.oauth2.core.OAuth2AccessToken;
|
|
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
|
|
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
|
|
import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames;
|
|
import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames;
|
|
@@ -43,31 +47,19 @@ import org.springframework.security.oauth2.core.oidc.StandardClaimNames;
|
|
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
|
|
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
|
|
import org.springframework.security.oauth2.core.oidc.user.OidcUserAuthority;
|
|
import org.springframework.security.oauth2.core.oidc.user.OidcUserAuthority;
|
|
|
|
|
|
-import java.util.Arrays;
|
|
|
|
-import java.util.HashMap;
|
|
|
|
-import java.util.LinkedHashSet;
|
|
|
|
-import java.util.Map;
|
|
|
|
-import java.util.Set;
|
|
|
|
-import java.util.concurrent.TimeUnit;
|
|
|
|
-
|
|
|
|
import static org.assertj.core.api.Assertions.assertThat;
|
|
import static org.assertj.core.api.Assertions.assertThat;
|
|
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
|
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
|
import static org.hamcrest.CoreMatchers.containsString;
|
|
import static org.hamcrest.CoreMatchers.containsString;
|
|
-import static org.mockito.Mockito.mock;
|
|
|
|
-import static org.mockito.Mockito.when;
|
|
|
|
|
|
+import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration;
|
|
|
|
+import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.scopes;
|
|
|
|
|
|
/**
|
|
/**
|
|
* Tests for {@link OidcUserService}.
|
|
* Tests for {@link OidcUserService}.
|
|
*
|
|
*
|
|
* @author Joe Grandja
|
|
* @author Joe Grandja
|
|
*/
|
|
*/
|
|
-@PowerMockIgnore({"okhttp3.*", "okio.Buffer"})
|
|
|
|
-@PrepareForTest(ClientRegistration.class)
|
|
|
|
-@RunWith(PowerMockRunner.class)
|
|
|
|
public class OidcUserServiceTests {
|
|
public class OidcUserServiceTests {
|
|
- private ClientRegistration clientRegistration;
|
|
|
|
- private ClientRegistration.ProviderDetails providerDetails;
|
|
|
|
- private ClientRegistration.ProviderDetails.UserInfoEndpoint userInfoEndpoint;
|
|
|
|
|
|
+ private ClientRegistration.Builder clientRegistrationBuilder;
|
|
private OAuth2AccessToken accessToken;
|
|
private OAuth2AccessToken accessToken;
|
|
private OidcIdToken idToken;
|
|
private OidcIdToken idToken;
|
|
private OidcUserService userService = new OidcUserService();
|
|
private OidcUserService userService = new OidcUserService();
|
|
@@ -80,26 +72,17 @@ public class OidcUserServiceTests {
|
|
public void setup() throws Exception {
|
|
public void setup() throws Exception {
|
|
this.server = new MockWebServer();
|
|
this.server = new MockWebServer();
|
|
this.server.start();
|
|
this.server.start();
|
|
- this.clientRegistration = mock(ClientRegistration.class);
|
|
|
|
- this.providerDetails = mock(ClientRegistration.ProviderDetails.class);
|
|
|
|
- this.userInfoEndpoint = mock(ClientRegistration.ProviderDetails.UserInfoEndpoint.class);
|
|
|
|
- when(this.clientRegistration.getProviderDetails()).thenReturn(this.providerDetails);
|
|
|
|
- when(this.providerDetails.getUserInfoEndpoint()).thenReturn(this.userInfoEndpoint);
|
|
|
|
- when(this.clientRegistration.getAuthorizationGrantType()).thenReturn(AuthorizationGrantType.AUTHORIZATION_CODE);
|
|
|
|
|
|
+ this.clientRegistrationBuilder = clientRegistration()
|
|
|
|
+ .userInfoUri(null)
|
|
|
|
+ .userInfoAuthenticationMethod(AuthenticationMethod.HEADER)
|
|
|
|
+ .userNameAttributeName(StandardClaimNames.SUB);
|
|
|
|
|
|
- when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER);
|
|
|
|
- when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn(StandardClaimNames.SUB);
|
|
|
|
|
|
+ this.accessToken = scopes(OidcScopes.OPENID, OidcScopes.PROFILE);
|
|
|
|
|
|
- this.accessToken = mock(OAuth2AccessToken.class);
|
|
|
|
- Set<String> authorizedScopes = new LinkedHashSet<>(Arrays.asList(OidcScopes.OPENID, OidcScopes.PROFILE));
|
|
|
|
- when(this.accessToken.getScopes()).thenReturn(authorizedScopes);
|
|
|
|
-
|
|
|
|
- this.idToken = mock(OidcIdToken.class);
|
|
|
|
Map<String, Object> idTokenClaims = new HashMap<>();
|
|
Map<String, Object> idTokenClaims = new HashMap<>();
|
|
idTokenClaims.put(IdTokenClaimNames.ISS, "https://provider.com");
|
|
idTokenClaims.put(IdTokenClaimNames.ISS, "https://provider.com");
|
|
idTokenClaims.put(IdTokenClaimNames.SUB, "subject1");
|
|
idTokenClaims.put(IdTokenClaimNames.SUB, "subject1");
|
|
- when(this.idToken.getClaims()).thenReturn(idTokenClaims);
|
|
|
|
- when(this.idToken.getSubject()).thenReturn("subject1");
|
|
|
|
|
|
+ this.idToken = new OidcIdToken("access-token", Instant.MIN, Instant.MAX, idTokenClaims);
|
|
|
|
|
|
this.userService.setOauth2UserService(new DefaultOAuth2UserService());
|
|
this.userService.setOauth2UserService(new DefaultOAuth2UserService());
|
|
}
|
|
}
|
|
@@ -123,22 +106,23 @@ public class OidcUserServiceTests {
|
|
|
|
|
|
@Test
|
|
@Test
|
|
public void loadUserWhenUserInfoUriIsNullThenUserInfoEndpointNotRequested() {
|
|
public void loadUserWhenUserInfoUriIsNullThenUserInfoEndpointNotRequested() {
|
|
- when(this.userInfoEndpoint.getUri()).thenReturn(null);
|
|
|
|
-
|
|
|
|
OidcUser user = this.userService.loadUser(
|
|
OidcUser user = this.userService.loadUser(
|
|
- new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
|
|
|
|
|
|
+ new OidcUserRequest(this.clientRegistrationBuilder.build(), this.accessToken, this.idToken));
|
|
assertThat(user.getUserInfo()).isNull();
|
|
assertThat(user.getUserInfo()).isNull();
|
|
}
|
|
}
|
|
|
|
|
|
@Test
|
|
@Test
|
|
public void loadUserWhenAuthorizedScopesDoesNotContainUserInfoScopesThenUserInfoEndpointNotRequested() {
|
|
public void loadUserWhenAuthorizedScopesDoesNotContainUserInfoScopesThenUserInfoEndpointNotRequested() {
|
|
- Set<String> authorizedScopes = new LinkedHashSet<>(Arrays.asList("scope1", "scope2"));
|
|
|
|
- when(this.accessToken.getScopes()).thenReturn(authorizedScopes);
|
|
|
|
|
|
+ ClientRegistration clientRegistration = this.clientRegistrationBuilder
|
|
|
|
+ .userInfoUri("http://provider.com/user").build();
|
|
|
|
|
|
- when(this.userInfoEndpoint.getUri()).thenReturn("http://provider.com/user");
|
|
|
|
|
|
+ Set<String> authorizedScopes = new LinkedHashSet<>(Arrays.asList("scope1", "scope2"));
|
|
|
|
+ OAuth2AccessToken accessToken = new OAuth2AccessToken(
|
|
|
|
+ OAuth2AccessToken.TokenType.BEARER, "access-token",
|
|
|
|
+ Instant.MIN, Instant.MAX, authorizedScopes);
|
|
|
|
|
|
OidcUser user = this.userService.loadUser(
|
|
OidcUser user = this.userService.loadUser(
|
|
- new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
|
|
|
|
|
|
+ new OidcUserRequest(clientRegistration, accessToken, this.idToken));
|
|
assertThat(user.getUserInfo()).isNull();
|
|
assertThat(user.getUserInfo()).isNull();
|
|
}
|
|
}
|
|
|
|
|
|
@@ -156,11 +140,11 @@ public class OidcUserServiceTests {
|
|
|
|
|
|
String userInfoUri = this.server.url("/user").toString();
|
|
String userInfoUri = this.server.url("/user").toString();
|
|
|
|
|
|
- when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
|
|
|
|
- when(this.accessToken.getTokenValue()).thenReturn("access-token");
|
|
|
|
|
|
+ ClientRegistration clientRegistration = this.clientRegistrationBuilder
|
|
|
|
+ .userInfoUri(userInfoUri).build();
|
|
|
|
|
|
OidcUser user = this.userService.loadUser(
|
|
OidcUser user = this.userService.loadUser(
|
|
- new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
|
|
|
|
|
|
+ new OidcUserRequest(clientRegistration, this.accessToken, this.idToken));
|
|
|
|
|
|
assertThat(user.getIdToken()).isNotNull();
|
|
assertThat(user.getIdToken()).isNotNull();
|
|
assertThat(user.getUserInfo()).isNotNull();
|
|
assertThat(user.getUserInfo()).isNotNull();
|
|
@@ -196,11 +180,11 @@ public class OidcUserServiceTests {
|
|
|
|
|
|
String userInfoUri = this.server.url("/user").toString();
|
|
String userInfoUri = this.server.url("/user").toString();
|
|
|
|
|
|
- when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
|
|
|
|
- when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn(StandardClaimNames.EMAIL);
|
|
|
|
- when(this.accessToken.getTokenValue()).thenReturn("access-token");
|
|
|
|
|
|
+ ClientRegistration clientRegistration = this.clientRegistrationBuilder
|
|
|
|
+ .userInfoUri(userInfoUri)
|
|
|
|
+ .userNameAttributeName(StandardClaimNames.EMAIL).build();
|
|
|
|
|
|
- this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
|
|
|
|
|
|
+ this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken));
|
|
}
|
|
}
|
|
|
|
|
|
@Test
|
|
@Test
|
|
@@ -215,10 +199,10 @@ public class OidcUserServiceTests {
|
|
|
|
|
|
String userInfoUri = this.server.url("/user").toString();
|
|
String userInfoUri = this.server.url("/user").toString();
|
|
|
|
|
|
- when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
|
|
|
|
- when(this.accessToken.getTokenValue()).thenReturn("access-token");
|
|
|
|
|
|
+ ClientRegistration clientRegistration = this.clientRegistrationBuilder
|
|
|
|
+ .userInfoUri(userInfoUri).build();
|
|
|
|
|
|
- this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
|
|
|
|
|
|
+ this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken));
|
|
}
|
|
}
|
|
|
|
|
|
@Test
|
|
@Test
|
|
@@ -238,10 +222,10 @@ public class OidcUserServiceTests {
|
|
|
|
|
|
String userInfoUri = this.server.url("/user").toString();
|
|
String userInfoUri = this.server.url("/user").toString();
|
|
|
|
|
|
- when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
|
|
|
|
- when(this.accessToken.getTokenValue()).thenReturn("access-token");
|
|
|
|
|
|
+ ClientRegistration clientRegistration = this.clientRegistrationBuilder
|
|
|
|
+ .userInfoUri(userInfoUri).build();
|
|
|
|
|
|
- this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
|
|
|
|
|
|
+ this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken));
|
|
}
|
|
}
|
|
|
|
|
|
@Test
|
|
@Test
|
|
@@ -253,10 +237,10 @@ public class OidcUserServiceTests {
|
|
|
|
|
|
String userInfoUri = server.url("/user").toString();
|
|
String userInfoUri = server.url("/user").toString();
|
|
|
|
|
|
- when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
|
|
|
|
- when(this.accessToken.getTokenValue()).thenReturn("access-token");
|
|
|
|
|
|
+ ClientRegistration clientRegistration = this.clientRegistrationBuilder
|
|
|
|
+ .userInfoUri(userInfoUri).build();
|
|
|
|
|
|
- this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
|
|
|
|
|
|
+ this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken));
|
|
}
|
|
}
|
|
|
|
|
|
@Test
|
|
@Test
|
|
@@ -266,10 +250,10 @@ public class OidcUserServiceTests {
|
|
|
|
|
|
String userInfoUri = "http://invalid-provider.com/user";
|
|
String userInfoUri = "http://invalid-provider.com/user";
|
|
|
|
|
|
- when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
|
|
|
|
- when(this.accessToken.getTokenValue()).thenReturn("access-token");
|
|
|
|
|
|
+ ClientRegistration clientRegistration = this.clientRegistrationBuilder
|
|
|
|
+ .userInfoUri(userInfoUri).build();
|
|
|
|
|
|
- this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
|
|
|
|
|
|
+ this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken));
|
|
}
|
|
}
|
|
|
|
|
|
@Test
|
|
@Test
|
|
@@ -286,12 +270,12 @@ public class OidcUserServiceTests {
|
|
|
|
|
|
String userInfoUri = this.server.url("/user").toString();
|
|
String userInfoUri = this.server.url("/user").toString();
|
|
|
|
|
|
- when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
|
|
|
|
- when(this.userInfoEndpoint.getUserNameAttributeName()).thenReturn(StandardClaimNames.EMAIL);
|
|
|
|
- when(this.accessToken.getTokenValue()).thenReturn("access-token");
|
|
|
|
|
|
+ ClientRegistration clientRegistration = this.clientRegistrationBuilder
|
|
|
|
+ .userInfoUri(userInfoUri)
|
|
|
|
+ .userNameAttributeName(StandardClaimNames.EMAIL).build();
|
|
|
|
|
|
OidcUser user = this.userService.loadUser(
|
|
OidcUser user = this.userService.loadUser(
|
|
- new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
|
|
|
|
|
|
+ new OidcUserRequest(clientRegistration, this.accessToken, this.idToken));
|
|
|
|
|
|
assertThat(user.getName()).isEqualTo("user1@example.com");
|
|
assertThat(user.getName()).isEqualTo("user1@example.com");
|
|
}
|
|
}
|
|
@@ -311,10 +295,10 @@ public class OidcUserServiceTests {
|
|
|
|
|
|
String userInfoUri = this.server.url("/user").toString();
|
|
String userInfoUri = this.server.url("/user").toString();
|
|
|
|
|
|
- when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
|
|
|
|
- when(this.accessToken.getTokenValue()).thenReturn("access-token");
|
|
|
|
|
|
+ ClientRegistration clientRegistration = this.clientRegistrationBuilder
|
|
|
|
+ .userInfoUri(userInfoUri).build();
|
|
|
|
|
|
- this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
|
|
|
|
|
|
+ this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken));
|
|
assertThat(this.server.takeRequest(1, TimeUnit.SECONDS).getHeader(HttpHeaders.ACCEPT))
|
|
assertThat(this.server.takeRequest(1, TimeUnit.SECONDS).getHeader(HttpHeaders.ACCEPT))
|
|
.isEqualTo(MediaType.APPLICATION_JSON_VALUE);
|
|
.isEqualTo(MediaType.APPLICATION_JSON_VALUE);
|
|
}
|
|
}
|
|
@@ -334,11 +318,10 @@ public class OidcUserServiceTests {
|
|
|
|
|
|
String userInfoUri = this.server.url("/user").toString();
|
|
String userInfoUri = this.server.url("/user").toString();
|
|
|
|
|
|
- when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
|
|
|
|
- when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.HEADER);
|
|
|
|
- when(this.accessToken.getTokenValue()).thenReturn("access-token");
|
|
|
|
|
|
+ ClientRegistration clientRegistration = this.clientRegistrationBuilder
|
|
|
|
+ .userInfoUri(userInfoUri).build();
|
|
|
|
|
|
- this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
|
|
|
|
|
|
+ this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken));
|
|
RecordedRequest request = this.server.takeRequest();
|
|
RecordedRequest request = this.server.takeRequest();
|
|
assertThat(request.getMethod()).isEqualTo(HttpMethod.GET.name());
|
|
assertThat(request.getMethod()).isEqualTo(HttpMethod.GET.name());
|
|
assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE);
|
|
assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE);
|
|
@@ -360,11 +343,11 @@ public class OidcUserServiceTests {
|
|
|
|
|
|
String userInfoUri = this.server.url("/user").toString();
|
|
String userInfoUri = this.server.url("/user").toString();
|
|
|
|
|
|
- when(this.userInfoEndpoint.getUri()).thenReturn(userInfoUri);
|
|
|
|
- when(this.userInfoEndpoint.getAuthenticationMethod()).thenReturn(AuthenticationMethod.FORM);
|
|
|
|
- when(this.accessToken.getTokenValue()).thenReturn("access-token");
|
|
|
|
|
|
+ ClientRegistration clientRegistration = this.clientRegistrationBuilder
|
|
|
|
+ .userInfoUri(userInfoUri)
|
|
|
|
+ .userInfoAuthenticationMethod(AuthenticationMethod.FORM).build();
|
|
|
|
|
|
- this.userService.loadUser(new OidcUserRequest(this.clientRegistration, this.accessToken, this.idToken));
|
|
|
|
|
|
+ this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken));
|
|
RecordedRequest request = this.server.takeRequest();
|
|
RecordedRequest request = this.server.takeRequest();
|
|
assertThat(request.getMethod()).isEqualTo(HttpMethod.POST.name());
|
|
assertThat(request.getMethod()).isEqualTo(HttpMethod.POST.name());
|
|
assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE);
|
|
assertThat(request.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_VALUE);
|