|
@@ -20,6 +20,7 @@ import org.junit.Rule;
|
|
|
import org.junit.Test;
|
|
|
import org.junit.rules.ExpectedException;
|
|
|
import org.junit.runner.RunWith;
|
|
|
+import org.mockito.ArgumentCaptor;
|
|
|
import org.mockito.stubbing.Answer;
|
|
|
import org.powermock.api.mockito.PowerMockito;
|
|
|
import org.powermock.core.classloader.annotations.PrepareForTest;
|
|
@@ -37,7 +38,6 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken;
|
|
|
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
|
|
|
import org.springframework.security.oauth2.core.OAuth2Error;
|
|
|
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
|
|
|
-import org.springframework.security.oauth2.core.OAuth2RefreshToken;
|
|
|
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
|
|
|
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
|
|
|
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
|
|
@@ -55,6 +55,7 @@ import java.util.HashMap;
|
|
|
import java.util.LinkedHashSet;
|
|
|
import java.util.List;
|
|
|
import java.util.Map;
|
|
|
+import java.util.Set;
|
|
|
|
|
|
import static org.assertj.core.api.Assertions.assertThat;
|
|
|
import static org.hamcrest.CoreMatchers.containsString;
|
|
@@ -78,8 +79,6 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
|
|
|
private OAuth2AuthorizationExchange authorizationExchange;
|
|
|
private OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient;
|
|
|
private OAuth2AccessTokenResponse accessTokenResponse;
|
|
|
- private OAuth2AccessToken accessToken;
|
|
|
- private OAuth2RefreshToken refreshToken;
|
|
|
private OAuth2UserService<OidcUserRequest, OidcUser> userService;
|
|
|
private OidcAuthorizationCodeAuthenticationProvider authenticationProvider;
|
|
|
|
|
@@ -95,9 +94,7 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
|
|
|
this.authorizationResponse = mock(OAuth2AuthorizationResponse.class);
|
|
|
this.authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, this.authorizationResponse);
|
|
|
this.accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class);
|
|
|
- this.accessTokenResponse = mock(OAuth2AccessTokenResponse.class);
|
|
|
- this.accessToken = mock(OAuth2AccessToken.class);
|
|
|
- this.refreshToken = mock(OAuth2RefreshToken.class);
|
|
|
+ this.accessTokenResponse = this.accessTokenSuccessResponse();
|
|
|
this.userService = mock(OAuth2UserService.class);
|
|
|
this.authenticationProvider = PowerMockito.spy(
|
|
|
new OidcAuthorizationCodeAuthenticationProvider(this.accessTokenResponseClient, this.userService));
|
|
@@ -111,11 +108,6 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
|
|
|
when(this.authorizationResponse.getState()).thenReturn("12345");
|
|
|
when(this.authorizationRequest.getRedirectUri()).thenReturn("http://example.com");
|
|
|
when(this.authorizationResponse.getRedirectUri()).thenReturn("http://example.com");
|
|
|
- when(this.accessTokenResponse.getAccessToken()).thenReturn(this.accessToken);
|
|
|
- when(this.accessTokenResponse.getRefreshToken()).thenReturn(this.refreshToken);
|
|
|
- Map<String, Object> additionalParameters = new HashMap<>();
|
|
|
- additionalParameters.put(OidcParameterNames.ID_TOKEN, "id-token");
|
|
|
- when(this.accessTokenResponse.getAdditionalParameters()).thenReturn(additionalParameters);
|
|
|
when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(this.accessTokenResponse);
|
|
|
}
|
|
|
|
|
@@ -194,7 +186,11 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
|
|
|
this.exception.expect(OAuth2AuthenticationException.class);
|
|
|
this.exception.expectMessage(containsString("invalid_id_token"));
|
|
|
|
|
|
- when(this.accessTokenResponse.getAdditionalParameters()).thenReturn(Collections.emptyMap());
|
|
|
+ OAuth2AccessTokenResponse accessTokenResponse =
|
|
|
+ OAuth2AccessTokenResponse.withResponse(this.accessTokenSuccessResponse())
|
|
|
+ .additionalParameters(Collections.emptyMap())
|
|
|
+ .build();
|
|
|
+ when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse);
|
|
|
|
|
|
this.authenticationProvider.authenticate(
|
|
|
new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
|
|
@@ -368,8 +364,8 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
|
|
|
assertThat(authentication.getAuthorities()).isEqualTo(authorities);
|
|
|
assertThat(authentication.getClientRegistration()).isEqualTo(this.clientRegistration);
|
|
|
assertThat(authentication.getAuthorizationExchange()).isEqualTo(this.authorizationExchange);
|
|
|
- assertThat(authentication.getAccessToken()).isEqualTo(this.accessToken);
|
|
|
- assertThat(authentication.getRefreshToken()).isEqualTo(this.refreshToken);
|
|
|
+ assertThat(authentication.getAccessToken()).isEqualTo(this.accessTokenResponse.getAccessToken());
|
|
|
+ assertThat(authentication.getRefreshToken()).isEqualTo(this.accessTokenResponse.getRefreshToken());
|
|
|
}
|
|
|
|
|
|
@Test
|
|
@@ -400,6 +396,30 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
|
|
|
assertThat(authentication.getAuthorities()).isEqualTo(mappedAuthorities);
|
|
|
}
|
|
|
|
|
|
+ // gh-5368
|
|
|
+ @Test
|
|
|
+ public void authenticateWhenTokenSuccessResponseThenAdditionalParametersAddedToUserRequest() throws Exception {
|
|
|
+ Map<String, Object> claims = new HashMap<>();
|
|
|
+ claims.put(IdTokenClaimNames.ISS, "https://provider.com");
|
|
|
+ claims.put(IdTokenClaimNames.SUB, "subject1");
|
|
|
+ claims.put(IdTokenClaimNames.AUD, Arrays.asList("client1", "client2"));
|
|
|
+ claims.put(IdTokenClaimNames.AZP, "client1");
|
|
|
+ this.setUpIdToken(claims);
|
|
|
+
|
|
|
+ OidcUser principal = mock(OidcUser.class);
|
|
|
+ List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
|
|
|
+ when(principal.getAuthorities()).thenAnswer(
|
|
|
+ (Answer<List<GrantedAuthority>>) invocation -> authorities);
|
|
|
+ ArgumentCaptor<OidcUserRequest> userRequestArgCaptor = ArgumentCaptor.forClass(OidcUserRequest.class);
|
|
|
+ when(this.userService.loadUser(userRequestArgCaptor.capture())).thenReturn(principal);
|
|
|
+
|
|
|
+ this.authenticationProvider.authenticate(new OAuth2LoginAuthenticationToken(
|
|
|
+ this.clientRegistration, this.authorizationExchange));
|
|
|
+
|
|
|
+ assertThat(userRequestArgCaptor.getValue().getAdditionalParameters()).containsAllEntriesOf(
|
|
|
+ this.accessTokenResponse.getAdditionalParameters());
|
|
|
+ }
|
|
|
+
|
|
|
private void setUpIdToken(Map<String, Object> claims) throws Exception {
|
|
|
Instant issuedAt = Instant.now();
|
|
|
Instant expiresAt = Instant.from(issuedAt).plusSeconds(3600);
|
|
@@ -416,4 +436,23 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
|
|
|
when(jwtDecoder.decode(anyString())).thenReturn(idToken);
|
|
|
PowerMockito.doReturn(jwtDecoder).when(this.authenticationProvider, "getJwtDecoder", any(ClientRegistration.class));
|
|
|
}
|
|
|
+
|
|
|
+ private OAuth2AccessTokenResponse accessTokenSuccessResponse() {
|
|
|
+ Instant expiresAt = Instant.now().plusSeconds(5);
|
|
|
+ Set<String> scopes = new LinkedHashSet<>(Arrays.asList("openid", "profile", "email"));
|
|
|
+ Map<String, Object> additionalParameters = new HashMap<>();
|
|
|
+ additionalParameters.put("param1", "value1");
|
|
|
+ additionalParameters.put("param2", "value2");
|
|
|
+ additionalParameters.put(OidcParameterNames.ID_TOKEN, "id-token");
|
|
|
+
|
|
|
+ return OAuth2AccessTokenResponse
|
|
|
+ .withToken("access-token-1234")
|
|
|
+ .tokenType(OAuth2AccessToken.TokenType.BEARER)
|
|
|
+ .expiresIn(expiresAt.getEpochSecond())
|
|
|
+ .scopes(scopes)
|
|
|
+ .refreshToken("refresh-token-1234")
|
|
|
+ .additionalParameters(additionalParameters)
|
|
|
+ .build();
|
|
|
+
|
|
|
+ }
|
|
|
}
|