|
@@ -19,7 +19,9 @@ import java.security.Principal;
|
|
import java.time.Instant;
|
|
import java.time.Instant;
|
|
import java.time.temporal.ChronoUnit;
|
|
import java.time.temporal.ChronoUnit;
|
|
import java.util.Collections;
|
|
import java.util.Collections;
|
|
|
|
+import java.util.HashMap;
|
|
import java.util.HashSet;
|
|
import java.util.HashSet;
|
|
|
|
+import java.util.Map;
|
|
import java.util.Set;
|
|
import java.util.Set;
|
|
|
|
|
|
import org.junit.Before;
|
|
import org.junit.Before;
|
|
@@ -36,23 +38,28 @@ import org.springframework.security.oauth2.core.OAuth2RefreshToken;
|
|
import org.springframework.security.oauth2.core.OAuth2RefreshToken2;
|
|
import org.springframework.security.oauth2.core.OAuth2RefreshToken2;
|
|
import org.springframework.security.oauth2.core.OAuth2TokenType;
|
|
import org.springframework.security.oauth2.core.OAuth2TokenType;
|
|
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
|
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
|
|
|
+import org.springframework.security.oauth2.core.oidc.OidcIdToken;
|
|
|
|
+import org.springframework.security.oauth2.core.oidc.OidcScopes;
|
|
|
|
+import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
|
|
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
|
|
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
|
|
import org.springframework.security.oauth2.jwt.JoseHeaderNames;
|
|
import org.springframework.security.oauth2.jwt.JoseHeaderNames;
|
|
import org.springframework.security.oauth2.jwt.Jwt;
|
|
import org.springframework.security.oauth2.jwt.Jwt;
|
|
import org.springframework.security.oauth2.jwt.JwtEncoder;
|
|
import org.springframework.security.oauth2.jwt.JwtEncoder;
|
|
|
|
+import org.springframework.security.oauth2.server.authorization.JwtEncodingContext;
|
|
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
|
|
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
|
|
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
|
|
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
|
|
|
|
+import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer;
|
|
import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations;
|
|
import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations;
|
|
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
|
|
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
|
|
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
|
|
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
|
|
-import org.springframework.security.oauth2.server.authorization.JwtEncodingContext;
|
|
|
|
-import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer;
|
|
|
|
|
|
|
|
|
|
+import static org.assertj.core.api.Assertions.entry;
|
|
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
|
|
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
|
|
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
|
|
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
|
|
import static org.mockito.ArgumentMatchers.any;
|
|
import static org.mockito.ArgumentMatchers.any;
|
|
import static org.mockito.ArgumentMatchers.eq;
|
|
import static org.mockito.ArgumentMatchers.eq;
|
|
import static org.mockito.Mockito.mock;
|
|
import static org.mockito.Mockito.mock;
|
|
|
|
+import static org.mockito.Mockito.times;
|
|
import static org.mockito.Mockito.verify;
|
|
import static org.mockito.Mockito.verify;
|
|
import static org.mockito.Mockito.when;
|
|
import static org.mockito.Mockito.when;
|
|
|
|
|
|
@@ -61,6 +68,7 @@ import static org.mockito.Mockito.when;
|
|
*
|
|
*
|
|
* @author Alexey Nesterov
|
|
* @author Alexey Nesterov
|
|
* @author Joe Grandja
|
|
* @author Joe Grandja
|
|
|
|
+ * @author Anoop Garlapati
|
|
* @since 0.0.3
|
|
* @since 0.0.3
|
|
*/
|
|
*/
|
|
public class OAuth2RefreshTokenAuthenticationProviderTests {
|
|
public class OAuth2RefreshTokenAuthenticationProviderTests {
|
|
@@ -156,6 +164,72 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
|
|
assertThat(updatedAuthorization.getRefreshToken()).isEqualTo(authorization.getRefreshToken());
|
|
assertThat(updatedAuthorization.getRefreshToken()).isEqualTo(authorization.getRefreshToken());
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ @Test
|
|
|
|
+ public void authenticateWhenValidRefreshTokenThenReturnIdToken() {
|
|
|
|
+ RegisteredClient registeredClient = TestRegisteredClients.registeredClient().scope(OidcScopes.OPENID).build();
|
|
|
|
+ OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
|
|
|
|
+ when(this.authorizationService.findByToken(
|
|
|
|
+ eq(authorization.getRefreshToken().getToken().getTokenValue()),
|
|
|
|
+ eq(OAuth2TokenType.REFRESH_TOKEN)))
|
|
|
|
+ .thenReturn(authorization);
|
|
|
|
+
|
|
|
|
+ OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
|
|
|
|
+ OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
|
|
|
|
+ authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, null, null);
|
|
|
|
+
|
|
|
|
+ OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
|
|
|
|
+ (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication);
|
|
|
|
+
|
|
|
|
+ ArgumentCaptor<JwtEncodingContext> jwtEncodingContextCaptor = ArgumentCaptor.forClass(JwtEncodingContext.class);
|
|
|
|
+ verify(this.jwtCustomizer, times(2)).customize(jwtEncodingContextCaptor.capture());
|
|
|
|
+ // Access Token context
|
|
|
|
+ JwtEncodingContext accessTokenContext = jwtEncodingContextCaptor.getAllValues().get(0);
|
|
|
|
+ assertThat(accessTokenContext.getRegisteredClient()).isEqualTo(registeredClient);
|
|
|
|
+ assertThat(accessTokenContext.<Authentication>getPrincipal()).isEqualTo(authorization.getAttribute(Principal.class.getName()));
|
|
|
|
+ assertThat(accessTokenContext.getAuthorization()).isEqualTo(authorization);
|
|
|
|
+ assertThat(accessTokenContext.getAuthorizedScopes())
|
|
|
|
+ .isEqualTo(authorization.getAttribute(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME));
|
|
|
|
+ assertThat(accessTokenContext.getTokenType()).isEqualTo(OAuth2TokenType.ACCESS_TOKEN);
|
|
|
|
+ assertThat(accessTokenContext.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.REFRESH_TOKEN);
|
|
|
|
+ assertThat(accessTokenContext.<OAuth2AuthorizationGrantAuthenticationToken>getAuthorizationGrant()).isEqualTo(authentication);
|
|
|
|
+ assertThat(accessTokenContext.getHeaders()).isNotNull();
|
|
|
|
+ assertThat(accessTokenContext.getClaims()).isNotNull();
|
|
|
|
+ Map<String, Object> claims = new HashMap<>();
|
|
|
|
+ accessTokenContext.getClaims().claims(claims::putAll);
|
|
|
|
+ assertThat(claims).flatExtracting(OAuth2ParameterNames.SCOPE)
|
|
|
|
+ .containsExactlyInAnyOrder(OidcScopes.OPENID, "scope1");
|
|
|
|
+ // ID Token context
|
|
|
|
+ JwtEncodingContext idTokenContext = jwtEncodingContextCaptor.getAllValues().get(1);
|
|
|
|
+ assertThat(idTokenContext.getRegisteredClient()).isEqualTo(registeredClient);
|
|
|
|
+ assertThat(idTokenContext.<Authentication>getPrincipal()).isEqualTo(authorization.getAttribute(Principal.class.getName()));
|
|
|
|
+ assertThat(idTokenContext.getAuthorization()).isEqualTo(authorization);
|
|
|
|
+ assertThat(idTokenContext.getAuthorizedScopes())
|
|
|
|
+ .isEqualTo(authorization.getAttribute(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME));
|
|
|
|
+ assertThat(idTokenContext.getTokenType().getValue()).isEqualTo(OidcParameterNames.ID_TOKEN);
|
|
|
|
+ assertThat(idTokenContext.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.REFRESH_TOKEN);
|
|
|
|
+ assertThat(idTokenContext.<OAuth2AuthorizationGrantAuthenticationToken>getAuthorizationGrant()).isEqualTo(authentication);
|
|
|
|
+ assertThat(idTokenContext.getHeaders()).isNotNull();
|
|
|
|
+ assertThat(idTokenContext.getClaims()).isNotNull();
|
|
|
|
+
|
|
|
|
+ verify(this.jwtEncoder, times(2)).encode(any(), any()); // Access token and ID Token
|
|
|
|
+
|
|
|
|
+ ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
|
|
|
|
+ verify(this.authorizationService).save(authorizationCaptor.capture());
|
|
|
|
+ OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
|
|
|
|
+
|
|
|
|
+ assertThat(accessTokenAuthentication.getRegisteredClient().getId()).isEqualTo(updatedAuthorization.getRegisteredClientId());
|
|
|
|
+ assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal);
|
|
|
|
+ assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken().getToken());
|
|
|
|
+ assertThat(updatedAuthorization.getAccessToken()).isNotEqualTo(authorization.getAccessToken());
|
|
|
|
+ OAuth2Authorization.Token<OidcIdToken> idToken = updatedAuthorization.getToken(OidcIdToken.class);
|
|
|
|
+ assertThat(idToken).isNotNull();
|
|
|
|
+ assertThat(accessTokenAuthentication.getAdditionalParameters())
|
|
|
|
+ .containsExactly(entry(OidcParameterNames.ID_TOKEN, idToken.getToken().getTokenValue()));
|
|
|
|
+ assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getRefreshToken().getToken());
|
|
|
|
+ // By default, refresh token is reused
|
|
|
|
+ assertThat(updatedAuthorization.getRefreshToken()).isEqualTo(authorization.getRefreshToken());
|
|
|
|
+ }
|
|
|
|
+
|
|
@Test
|
|
@Test
|
|
public void authenticateWhenReuseRefreshTokensFalseThenReturnNewRefreshToken() {
|
|
public void authenticateWhenReuseRefreshTokensFalseThenReturnNewRefreshToken() {
|
|
RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
|
|
RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
|