浏览代码

Ensure ID Token is updated after refresh token

Signed-off-by: Hao <kyrieeeee2@gmail.com>
Hao 6 月之前
父节点
当前提交
fc1469ad5e

+ 15 - 1
config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java

@@ -34,6 +34,9 @@ import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
 import org.springframework.beans.factory.support.BeanDefinitionBuilder;
 import org.springframework.beans.factory.support.BeanDefinitionRegistry;
 import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
+import org.springframework.context.ApplicationContext;
+import org.springframework.context.ApplicationContextAware;
+import org.springframework.context.ApplicationEventPublisher;
 import org.springframework.context.annotation.AnnotationBeanNameGenerator;
 import org.springframework.context.annotation.Bean;
 import org.springframework.context.annotation.Configuration;
@@ -160,7 +163,7 @@ final class OAuth2ClientConfiguration {
 	 * @since 6.2.0
 	 */
 	static final class OAuth2AuthorizedClientManagerRegistrar
-			implements BeanDefinitionRegistryPostProcessor, BeanFactoryAware {
+			implements ApplicationContextAware, BeanDefinitionRegistryPostProcessor, BeanFactoryAware {
 
 		static final String BEAN_NAME = "authorizedClientManagerRegistrar";
 
@@ -179,6 +182,8 @@ final class OAuth2ClientConfiguration {
 
 		private final AnnotationBeanNameGenerator beanNameGenerator = new AnnotationBeanNameGenerator();
 
+		private ApplicationEventPublisher eventPublisher;
+
 		private ListableBeanFactory beanFactory;
 
 		@Override
@@ -302,6 +307,10 @@ final class OAuth2ClientConfiguration {
 				authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient);
 			}
 
+			if (this.eventPublisher != null) {
+				authorizedClientProvider.setApplicationEventPublisher(this.eventPublisher);
+			}
+
 			return authorizedClientProvider;
 		}
 
@@ -423,6 +432,11 @@ final class OAuth2ClientConfiguration {
 			return objectProvider.getIfAvailable();
 		}
 
+		@Override
+		public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
+			this.eventPublisher = applicationContext;
+		}
+
 	}
 
 }

+ 10 - 0
config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java

@@ -57,6 +57,7 @@ import org.springframework.security.oauth2.client.endpoint.DefaultAuthorizationC
 import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
 import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
 import org.springframework.security.oauth2.client.oidc.authentication.OidcAuthorizationCodeAuthenticationProvider;
+import org.springframework.security.oauth2.client.oidc.authentication.RefreshOidcIdTokenHandler;
 import org.springframework.security.oauth2.client.oidc.session.InMemoryOidcSessionRegistry;
 import org.springframework.security.oauth2.client.oidc.session.OidcSessionInformation;
 import org.springframework.security.oauth2.client.oidc.session.OidcSessionRegistry;
@@ -394,6 +395,15 @@ public final class OAuth2LoginConfigurer<B extends HttpSecurityBuilder<B>>
 				oidcAuthorizationCodeAuthenticationProvider.setAuthoritiesMapper(userAuthoritiesMapper);
 			}
 			http.authenticationProvider(this.postProcess(oidcAuthorizationCodeAuthenticationProvider));
+
+			RefreshOidcIdTokenHandler refreshOidcIdTokenHandler = new RefreshOidcIdTokenHandler();
+			if (this.getSecurityContextHolderStrategy() != null) {
+				refreshOidcIdTokenHandler.setSecurityContextHolderStrategy(this.getSecurityContextHolderStrategy());
+			}
+			if (jwtDecoderFactory != null) {
+				refreshOidcIdTokenHandler.setJwtDecoderFactory(jwtDecoderFactory);
+			}
+			registerDelegateApplicationListener(refreshOidcIdTokenHandler);
 		}
 		else {
 			http.authenticationProvider(new OidcAuthenticationRequestChecker());

+ 7 - 0
config/src/main/java/org/springframework/security/config/http/OAuth2AuthorizedClientManagerRegistrar.java

@@ -34,6 +34,7 @@ import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
 import org.springframework.beans.factory.support.BeanDefinitionBuilder;
 import org.springframework.beans.factory.support.BeanDefinitionRegistry;
 import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
+import org.springframework.context.ApplicationEventPublisher;
 import org.springframework.context.annotation.AnnotationBeanNameGenerator;
 import org.springframework.core.ResolvableType;
 import org.springframework.security.oauth2.client.AuthorizationCodeOAuth2AuthorizedClientProvider;
@@ -197,6 +198,12 @@ final class OAuth2AuthorizedClientManagerRegistrar implements BeanDefinitionRegi
 			authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient);
 		}
 
+		ApplicationEventPublisher applicationEventPublisher = getBeanOfType(
+				ResolvableType.forClass(ApplicationEventPublisher.class));
+		if (applicationEventPublisher != null) {
+			authorizedClientProvider.setApplicationEventPublisher(applicationEventPublisher);
+		}
+
 		return authorizedClientProvider;
 	}
 

+ 17 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilder.java

@@ -25,6 +25,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.function.Consumer;
 
+import org.springframework.context.ApplicationEventPublisher;
 import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
 import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
 import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest;
@@ -359,6 +360,8 @@ public final class OAuth2AuthorizedClientProviderBuilder {
 
 		private OAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> accessTokenResponseClient;
 
+		private ApplicationEventPublisher eventPublisher;
+
 		private Duration clockSkew;
 
 		private Clock clock;
@@ -379,6 +382,17 @@ public final class OAuth2AuthorizedClientProviderBuilder {
 			return this;
 		}
 
+		/**
+		 * Sets the {@link ApplicationEventPublisher} used when an access token is
+		 * refreshed.
+		 * @param eventPublisher the {@link ApplicationEventPublisher}
+		 * @return the {@link RefreshTokenGrantBuilder}
+		 */
+		public RefreshTokenGrantBuilder eventPublisher(ApplicationEventPublisher eventPublisher) {
+			this.eventPublisher = eventPublisher;
+			return this;
+		}
+
 		/**
 		 * Sets the maximum acceptable clock skew, which is used when checking the access
 		 * token expiry. An access token is considered expired if
@@ -414,6 +428,9 @@ public final class OAuth2AuthorizedClientProviderBuilder {
 			if (this.accessTokenResponseClient != null) {
 				authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient);
 			}
+			if (this.eventPublisher != null) {
+				authorizedClientProvider.setApplicationEventPublisher(this.eventPublisher);
+			}
 			if (this.clockSkew != null) {
 				authorizedClientProvider.setClockSkew(this.clockSkew);
 			}

+ 23 - 3
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java

@@ -24,10 +24,13 @@ import java.util.Collections;
 import java.util.HashSet;
 import java.util.Set;
 
+import org.springframework.context.ApplicationEventPublisher;
+import org.springframework.context.ApplicationEventPublisherAware;
 import org.springframework.lang.Nullable;
 import org.springframework.security.oauth2.client.endpoint.DefaultRefreshTokenTokenResponseClient;
 import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
 import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest;
+import org.springframework.security.oauth2.client.event.OAuth2TokenRefreshedEvent;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
 import org.springframework.security.oauth2.core.OAuth2Token;
@@ -43,10 +46,13 @@ import org.springframework.util.Assert;
  * @see OAuth2AuthorizedClientProvider
  * @see DefaultRefreshTokenTokenResponseClient
  */
-public final class RefreshTokenOAuth2AuthorizedClientProvider implements OAuth2AuthorizedClientProvider {
+public final class RefreshTokenOAuth2AuthorizedClientProvider
+		implements OAuth2AuthorizedClientProvider, ApplicationEventPublisherAware {
 
 	private OAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> accessTokenResponseClient = new DefaultRefreshTokenTokenResponseClient();
 
+	private ApplicationEventPublisher eventPublisher;
+
 	private Duration clockSkew = Duration.ofSeconds(60);
 
 	private Clock clock = Clock.systemUTC();
@@ -91,8 +97,17 @@ public final class RefreshTokenOAuth2AuthorizedClientProvider implements OAuth2A
 				authorizedClient.getClientRegistration(), authorizedClient.getAccessToken(),
 				authorizedClient.getRefreshToken(), scopes);
 		OAuth2AccessTokenResponse tokenResponse = getTokenResponse(authorizedClient, refreshTokenGrantRequest);
-		return new OAuth2AuthorizedClient(context.getAuthorizedClient().getClientRegistration(),
-				context.getPrincipal().getName(), tokenResponse.getAccessToken(), tokenResponse.getRefreshToken());
+
+		OAuth2AuthorizedClient updatedOAuth2AuthorizedClient = new OAuth2AuthorizedClient(
+				authorizedClient.getClientRegistration(), context.getPrincipal().getName(),
+				tokenResponse.getAccessToken(), tokenResponse.getRefreshToken());
+
+		if (this.eventPublisher != null) {
+			this.eventPublisher
+				.publishEvent(new OAuth2TokenRefreshedEvent(this, updatedOAuth2AuthorizedClient, tokenResponse));
+		}
+
+		return updatedOAuth2AuthorizedClient;
 	}
 
 	private OAuth2AccessTokenResponse getTokenResponse(OAuth2AuthorizedClient authorizedClient,
@@ -149,4 +164,9 @@ public final class RefreshTokenOAuth2AuthorizedClientProvider implements OAuth2A
 		this.clock = clock;
 	}
 
+	@Override
+	public void setApplicationEventPublisher(ApplicationEventPublisher applicationEventPublisher) {
+		this.eventPublisher = applicationEventPublisher;
+	}
+
 }

+ 47 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/event/OAuth2TokenRefreshedEvent.java

@@ -0,0 +1,47 @@
+/*
+ * Copyright 2002-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.client.event;
+
+import org.springframework.context.ApplicationEvent;
+import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
+
+/**
+ * An event that is published when an OAuth2 access token is refreshed.
+ */
+public class OAuth2TokenRefreshedEvent extends ApplicationEvent {
+
+	private final OAuth2AuthorizedClient authorizedClient;
+
+	private final OAuth2AccessTokenResponse accessTokenResponse;
+
+	public OAuth2TokenRefreshedEvent(Object source, OAuth2AuthorizedClient authorizedClient,
+			OAuth2AccessTokenResponse accessTokenResponse) {
+		super(source);
+		this.authorizedClient = authorizedClient;
+		this.accessTokenResponse = accessTokenResponse;
+	}
+
+	public OAuth2AuthorizedClient getAuthorizedClient() {
+		return this.authorizedClient;
+	}
+
+	public OAuth2AccessTokenResponse getAccessTokenResponse() {
+		return this.accessTokenResponse;
+	}
+
+}

+ 139 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/RefreshOidcIdTokenHandler.java

@@ -0,0 +1,139 @@
+/*
+ * Copyright 2002-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.client.oidc.authentication;
+
+import java.util.Map;
+
+import org.springframework.context.ApplicationListener;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.context.SecurityContext;
+import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
+import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
+import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
+import org.springframework.security.oauth2.client.event.OAuth2TokenRefreshedEvent;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.OAuth2Error;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
+import org.springframework.security.oauth2.core.oidc.OidcIdToken;
+import org.springframework.security.oauth2.core.oidc.OidcScopes;
+import org.springframework.security.oauth2.core.oidc.StandardClaimNames;
+import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
+import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
+import org.springframework.security.oauth2.core.oidc.user.OidcUser;
+import org.springframework.security.oauth2.jwt.Jwt;
+import org.springframework.security.oauth2.jwt.JwtDecoder;
+import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
+import org.springframework.security.oauth2.jwt.JwtException;
+import org.springframework.util.Assert;
+
+/**
+ * An {@link ApplicationListener} that listens for {@link OAuth2TokenRefreshedEvent}s
+ */
+public class RefreshOidcIdTokenHandler implements ApplicationListener<OAuth2TokenRefreshedEvent> {
+
+	private static final String MISSING_ID_TOKEN_ERROR_CODE = "missing_id_token";
+
+	private static final String INVALID_ID_TOKEN_ERROR_CODE = "invalid_id_token";
+
+	private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
+		.getContextHolderStrategy();
+
+	private JwtDecoderFactory<ClientRegistration> jwtDecoderFactory = new OidcIdTokenDecoderFactory();
+
+	@Override
+	public void onApplicationEvent(OAuth2TokenRefreshedEvent event) {
+		OAuth2AuthorizedClient authorizedClient = event.getAuthorizedClient();
+
+		if (!authorizedClient.getClientRegistration().getScopes().contains(OidcScopes.OPENID)) {
+			return;
+		}
+
+		Authentication authentication = this.securityContextHolderStrategy.getContext().getAuthentication();
+		if (!(authentication instanceof OAuth2AuthenticationToken oauth2Authentication)) {
+			return;
+		}
+		if (!(authentication.getPrincipal() instanceof DefaultOidcUser defaultOidcUser)) {
+			return;
+		}
+
+		OAuth2AccessTokenResponse accessTokenResponse = event.getAccessTokenResponse();
+
+		String idToken = (String) accessTokenResponse.getAdditionalParameters().get(OidcParameterNames.ID_TOKEN);
+		if (idToken == null || idToken.isBlank()) {
+			OAuth2Error missingIdTokenError = new OAuth2Error(MISSING_ID_TOKEN_ERROR_CODE,
+					"ID token is missing in the token response", null);
+			throw new OAuth2AuthenticationException(missingIdTokenError, missingIdTokenError.toString());
+		}
+
+		ClientRegistration clientRegistration = authorizedClient.getClientRegistration();
+		OidcIdToken refreshedOidcToken = createOidcToken(clientRegistration, accessTokenResponse);
+		updateSecurityContext(oauth2Authentication, defaultOidcUser, refreshedOidcToken);
+	}
+
+	/**
+	 * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
+	 * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
+	 */
+	public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
+		this.securityContextHolderStrategy = securityContextHolderStrategy;
+	}
+
+	/**
+	 * Sets the {@link JwtDecoderFactory} used for {@link OidcIdToken} signature
+	 * verification. The factory returns a {@link JwtDecoder} associated to the provided
+	 * {@link ClientRegistration}.
+	 * @param jwtDecoderFactory the {@link JwtDecoderFactory} used for {@link OidcIdToken}
+	 * signature verification
+	 */
+	public final void setJwtDecoderFactory(JwtDecoderFactory<ClientRegistration> jwtDecoderFactory) {
+		Assert.notNull(jwtDecoderFactory, "jwtDecoderFactory cannot be null");
+		this.jwtDecoderFactory = jwtDecoderFactory;
+	}
+
+	private void updateSecurityContext(OAuth2AuthenticationToken oauth2Authentication, DefaultOidcUser defaultOidcUser,
+			OidcIdToken refreshedOidcToken) {
+		OidcUser oidcUser = new DefaultOidcUser(defaultOidcUser.getAuthorities(), refreshedOidcToken,
+				defaultOidcUser.getUserInfo(), StandardClaimNames.SUB);
+
+		SecurityContext context = this.securityContextHolderStrategy.createEmptyContext();
+		context.setAuthentication(new OAuth2AuthenticationToken(oidcUser, oidcUser.getAuthorities(),
+				oauth2Authentication.getAuthorizedClientRegistrationId()));
+
+		this.securityContextHolderStrategy.setContext(context);
+	}
+
+	private OidcIdToken createOidcToken(ClientRegistration clientRegistration,
+			OAuth2AccessTokenResponse accessTokenResponse) {
+		JwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(clientRegistration);
+		Jwt jwt = getJwt(accessTokenResponse, jwtDecoder);
+		return new OidcIdToken(jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaims());
+	}
+
+	private Jwt getJwt(OAuth2AccessTokenResponse accessTokenResponse, JwtDecoder jwtDecoder) {
+		try {
+			Map<String, Object> parameters = accessTokenResponse.getAdditionalParameters();
+			return jwtDecoder.decode((String) parameters.get(OidcParameterNames.ID_TOKEN));
+		}
+		catch (JwtException ex) {
+			OAuth2Error invalidIdTokenError = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, ex.getMessage(), null);
+			throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString(), ex);
+		}
+	}
+
+}

+ 53 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java

@@ -25,10 +25,12 @@ import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 import org.mockito.ArgumentCaptor;
 
+import org.springframework.context.ApplicationEventPublisher;
 import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
 import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest;
+import org.springframework.security.oauth2.client.event.OAuth2TokenRefreshedEvent;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
@@ -251,4 +253,55 @@ public class RefreshTokenOAuth2AuthorizedClientProviderTests {
 					+ OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME + "'");
 	}
 
+	@Test
+	public void shouldPublishEventWhenTokenRefreshed() {
+		OAuth2TokenRefreshedAwareEventPublisher eventPublisher = new OAuth2TokenRefreshedAwareEventPublisher();
+		this.authorizedClientProvider.setApplicationEventPublisher(eventPublisher);
+		// @formatter:off
+		OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses
+				.accessTokenResponse()
+				.refreshToken("new-refresh-token")
+				.build();
+		// @formatter:on
+		given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse);
+		// @formatter:off
+		OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
+				.withAuthorizedClient(this.authorizedClient)
+				.principal(this.principal)
+				.build();
+		// @formatter:on
+		this.authorizedClientProvider.authorize(authorizationContext);
+		assertThat(eventPublisher.flag).isTrue();
+	}
+
+	@Test
+	public void shouldNotPublishEventWhenTokenNotRefreshed() {
+		OAuth2TokenRefreshedAwareEventPublisher eventPublisher = new OAuth2TokenRefreshedAwareEventPublisher();
+		this.authorizedClientProvider.setApplicationEventPublisher(eventPublisher);
+
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration,
+				this.principal.getName(), TestOAuth2AccessTokens.noScopes(), this.authorizedClient.getRefreshToken());
+		// @formatter:off
+		OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
+				.withAuthorizedClient(authorizedClient)
+				.principal(this.principal)
+				.build();
+		// @formatter:on
+		this.authorizedClientProvider.authorize(authorizationContext);
+		assertThat(eventPublisher.flag).isFalse();
+	}
+
+	private static class OAuth2TokenRefreshedAwareEventPublisher implements ApplicationEventPublisher {
+
+		Boolean flag = false;
+
+		@Override
+		public void publishEvent(Object event) {
+			if (OAuth2TokenRefreshedEvent.class.isAssignableFrom(event.getClass())) {
+				this.flag = true;
+			}
+		}
+
+	}
+
 }

+ 284 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/RefreshOidcIdTokenHandlerTests.java

@@ -0,0 +1,284 @@
+/*
+ * Copyright 2002-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.client.oidc.authentication;
+
+import java.time.Instant;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.mockito.ArgumentCaptor;
+
+import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.core.authority.AuthorityUtils;
+import org.springframework.security.core.context.SecurityContext;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
+import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
+import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider;
+import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
+import org.springframework.security.oauth2.client.event.OAuth2TokenRefreshedEvent;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
+import org.springframework.security.oauth2.core.oidc.OidcIdToken;
+import org.springframework.security.oauth2.core.oidc.OidcScopes;
+import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
+import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
+import org.springframework.security.oauth2.core.user.DefaultOAuth2User;
+import org.springframework.security.oauth2.jwt.Jwt;
+import org.springframework.security.oauth2.jwt.JwtDecoder;
+import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
+import org.springframework.security.oauth2.jwt.JwtException;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.BDDMockito.given;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.verify;
+
+class RefreshOidcIdTokenHandlerTests {
+
+	private static final String EXISTING_ID_TOKEN_VALUE = "id-token-value";
+
+	private static final String REFRESHED_ID_TOKEN_VALUE = "new-id-token-value";
+
+	private static final String EXISTING_ACCESS_TOKEN_VALUE = "token-value";
+
+	private static final String REFRESHED_ACCESS_TOKEN_VALUE = "new-token-value";
+
+	private RefreshOidcIdTokenHandler handler;
+
+	private RefreshTokenOAuth2AuthorizedClientProvider provider;
+
+	private ClientRegistration clientRegistration;
+
+	private OAuth2AuthorizedClient authorizedClient;
+
+	private JwtDecoder jwtDecoder;
+
+	private SecurityContext securityContext;
+
+	private OidcIdToken existingIdToken;
+
+	@BeforeEach
+	void setUp() {
+		this.handler = new RefreshOidcIdTokenHandler();
+
+		this.clientRegistration = createClientRegistrationWithScopes(OidcScopes.OPENID);
+		this.authorizedClient = createAuthorizedClient(this.clientRegistration);
+
+		this.provider = mock(RefreshTokenOAuth2AuthorizedClientProvider.class);
+
+		JwtDecoderFactory<ClientRegistration> jwtDecoderFactory = mock(JwtDecoderFactory.class);
+		this.jwtDecoder = mock(JwtDecoder.class);
+		SecurityContextHolderStrategy securityContextHolderStrategy = mock(SecurityContextHolderStrategy.class);
+		this.securityContext = mock(SecurityContext.class);
+
+		this.handler.setJwtDecoderFactory(jwtDecoderFactory);
+		this.handler.setSecurityContextHolderStrategy(securityContextHolderStrategy);
+
+		given(jwtDecoderFactory.createDecoder(any())).willReturn(this.jwtDecoder);
+		given(securityContextHolderStrategy.createEmptyContext()).willReturn(this.securityContext);
+		given(securityContextHolderStrategy.getContext()).willReturn(this.securityContext);
+
+		Map<String, Object> claims = new HashMap<>();
+		claims.put("sub", "subject");
+		Jwt existingIdTokenJwt = new Jwt(EXISTING_ID_TOKEN_VALUE, Instant.now(), Instant.now().plusSeconds(3600),
+				Map.of("alg", "RS256"), claims);
+		Jwt refreshedIdTokenJwt = new Jwt(REFRESHED_ID_TOKEN_VALUE, Instant.now(), Instant.now().plusSeconds(3600),
+				Map.of("alg", "RS256"), claims);
+
+		this.existingIdToken = new OidcIdToken(existingIdTokenJwt.getTokenValue(), existingIdTokenJwt.getIssuedAt(),
+				existingIdTokenJwt.getExpiresAt(), existingIdTokenJwt.getClaims());
+
+		given(this.jwtDecoder.decode(existingIdTokenJwt.getTokenValue())).willReturn(existingIdTokenJwt);
+		given(this.jwtDecoder.decode(refreshedIdTokenJwt.getTokenValue())).willReturn(refreshedIdTokenJwt);
+	}
+
+	@Test
+	void handleEventWhenValidIdTokenThenUpdatesSecurityContext() {
+
+		DefaultOidcUser existingUser = new DefaultOidcUser(AuthorityUtils.createAuthorityList("ROLE_USER"),
+				this.existingIdToken);
+		OAuth2AuthenticationToken existingAuth = new OAuth2AuthenticationToken(existingUser,
+				existingUser.getAuthorities(), "registration-id");
+		given(this.securityContext.getAuthentication()).willReturn(existingAuth);
+
+		OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse
+			.withToken(REFRESHED_ACCESS_TOKEN_VALUE)
+			.tokenType(OAuth2AccessToken.TokenType.BEARER)
+			.expiresIn(3600)
+			.additionalParameters(Map.of(OidcParameterNames.ID_TOKEN, REFRESHED_ID_TOKEN_VALUE))
+			.build();
+
+		OAuth2TokenRefreshedEvent event = new OAuth2TokenRefreshedEvent(this.provider, this.authorizedClient,
+				accessTokenResponse);
+		this.handler.onApplicationEvent(event);
+
+		ArgumentCaptor<OAuth2AuthenticationToken> authenticationCaptor = ArgumentCaptor
+			.forClass(OAuth2AuthenticationToken.class);
+		verify(this.securityContext).setAuthentication(authenticationCaptor.capture());
+
+		OAuth2AuthenticationToken newAuthentication = authenticationCaptor.getValue();
+		assertThat(newAuthentication.getPrincipal()).isInstanceOf(DefaultOidcUser.class);
+		DefaultOidcUser newUser = (DefaultOidcUser) newAuthentication.getPrincipal();
+		assertThat(newUser.getIdToken().getTokenValue()).isEqualTo(REFRESHED_ID_TOKEN_VALUE);
+	}
+
+	@Test
+	void handleEventWhenAuthorizedClientIsNotOidcThenDoesNothing() {
+
+		this.clientRegistration = createClientRegistrationWithScopes("read");
+		this.authorizedClient = createAuthorizedClient(this.clientRegistration);
+
+		OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse
+			.withToken(REFRESHED_ACCESS_TOKEN_VALUE)
+			.tokenType(OAuth2AccessToken.TokenType.BEARER)
+			.expiresIn(3600)
+			.additionalParameters(Map.of(OidcParameterNames.ID_TOKEN, REFRESHED_ID_TOKEN_VALUE))
+			.build();
+
+		OAuth2TokenRefreshedEvent event = new OAuth2TokenRefreshedEvent(this.provider, this.authorizedClient,
+				accessTokenResponse);
+
+		this.handler.onApplicationEvent(event);
+
+		verify(this.securityContext, never()).setAuthentication(any());
+		verify(this.jwtDecoder, never()).decode(any());
+	}
+
+	@Test
+	void handleEventWhenAuthenticationNotOAuth2AuthenticationTokenThenDoesNothing() {
+
+		given(this.securityContext.getAuthentication()).willReturn(mock(TestingAuthenticationToken.class));
+
+		OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse
+			.withToken(REFRESHED_ACCESS_TOKEN_VALUE)
+			.tokenType(OAuth2AccessToken.TokenType.BEARER)
+			.expiresIn(3600)
+			.additionalParameters(Map.of(OidcParameterNames.ID_TOKEN, REFRESHED_ID_TOKEN_VALUE))
+			.build();
+
+		OAuth2TokenRefreshedEvent event = new OAuth2TokenRefreshedEvent(this.provider, this.authorizedClient,
+				accessTokenResponse);
+
+		this.handler.onApplicationEvent(event);
+
+		verify(this.securityContext, never()).setAuthentication(any());
+	}
+
+	@Test
+	void handleEventWhenNotOidcUserThenDoesNothing() {
+
+		OAuth2AuthenticationToken existingAuth = new OAuth2AuthenticationToken(
+				new DefaultOAuth2User(Collections.emptySet(),
+						Collections.singletonMap("custom-attribute", "test-subject"), "custom-attribute"),
+				AuthorityUtils.createAuthorityList("ROLE_USER"), "registration-id");
+		given(this.securityContext.getAuthentication()).willReturn(existingAuth);
+
+		OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse
+			.withToken(REFRESHED_ACCESS_TOKEN_VALUE)
+			.tokenType(OAuth2AccessToken.TokenType.BEARER)
+			.expiresIn(3600)
+			.additionalParameters(Map.of(OidcParameterNames.ID_TOKEN, REFRESHED_ID_TOKEN_VALUE))
+			.build();
+
+		OAuth2TokenRefreshedEvent event = new OAuth2TokenRefreshedEvent(this.provider, this.authorizedClient,
+				accessTokenResponse);
+
+		this.handler.onApplicationEvent(event);
+
+		verify(this.securityContext, never()).setAuthentication(any());
+	}
+
+	@Test
+	void handleEventWhenMissingIdTokenThenThrowsException() {
+
+		DefaultOidcUser existingUser = new DefaultOidcUser(AuthorityUtils.createAuthorityList("ROLE_USER"),
+				this.existingIdToken);
+		OAuth2AuthenticationToken existingAuth = new OAuth2AuthenticationToken(existingUser,
+				existingUser.getAuthorities(), "registration-id");
+		given(this.securityContext.getAuthentication()).willReturn(existingAuth);
+
+		OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse
+			.withToken(REFRESHED_ACCESS_TOKEN_VALUE)
+			.tokenType(OAuth2AccessToken.TokenType.BEARER)
+			.expiresIn(3600)
+			.additionalParameters(new HashMap<>()) // missing ID token
+			.build();
+
+		OAuth2TokenRefreshedEvent event = new OAuth2TokenRefreshedEvent(this.provider, this.authorizedClient,
+				accessTokenResponse);
+
+		assertThatExceptionOfType(OAuth2AuthenticationException.class)
+			.isThrownBy(() -> this.handler.onApplicationEvent(event))
+			.withMessageContaining("missing_id_token");
+	}
+
+	@Test
+	void handleEventWhenInvalidIdTokenThenThrowsException() {
+
+		DefaultOidcUser existingUser = new DefaultOidcUser(AuthorityUtils.createAuthorityList("ROLE_USER"),
+				this.existingIdToken);
+		OAuth2AuthenticationToken existingAuth = new OAuth2AuthenticationToken(existingUser,
+				existingUser.getAuthorities(), "registration-id");
+		given(this.securityContext.getAuthentication()).willReturn(existingAuth);
+
+		given(this.jwtDecoder.decode(any())).willThrow(new JwtException("Invalid token"));
+
+		OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse
+			.withToken(REFRESHED_ACCESS_TOKEN_VALUE)
+			.tokenType(OAuth2AccessToken.TokenType.BEARER)
+			.expiresIn(3600)
+			.additionalParameters(Map.of(OidcParameterNames.ID_TOKEN, "invalid-id-token"))
+			.build();
+
+		OAuth2TokenRefreshedEvent event = new OAuth2TokenRefreshedEvent(this.provider, this.authorizedClient,
+				accessTokenResponse);
+
+		assertThatExceptionOfType(OAuth2AuthenticationException.class)
+			.isThrownBy(() -> this.handler.onApplicationEvent(event))
+			.withMessageContaining("invalid_id_token");
+	}
+
+	private ClientRegistration createClientRegistrationWithScopes(String... scope) {
+		return ClientRegistration.withRegistrationId("registration-id")
+			.clientId("client-id")
+			.clientSecret("secret")
+			.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+			.redirectUri("http://localhost")
+			.scope(scope)
+			.authorizationUri("https://provider.com/oauth2/authorize")
+			.tokenUri("https://provider.com/oauth2/token")
+			.jwkSetUri("https://provider.com/jwk")
+			.userInfoUri("https://provider.com/user")
+			.build();
+	}
+
+	private static OAuth2AuthorizedClient createAuthorizedClient(ClientRegistration clientRegistration) {
+		return new OAuth2AuthorizedClient(clientRegistration, "principal-name",
+				new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, EXISTING_ACCESS_TOKEN_VALUE, Instant.now(),
+						Instant.now().plusSeconds(3600)));
+	}
+
+}