소스 검색

JwtBearerOAuth2AuthorizedClientProvider checks for access token expiry

Fixes gh-9700
Joe Grandja 4 년 전
부모
커밋
761e3a9dd8

+ 50 - 7
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/JwtBearerOAuth2AuthorizedClientProvider.java

@@ -16,6 +16,10 @@
 
 package org.springframework.security.oauth2.client;
 
+import java.time.Clock;
+import java.time.Duration;
+import java.time.Instant;
+
 import org.springframework.lang.Nullable;
 import org.springframework.security.oauth2.client.endpoint.DefaultJwtBearerTokenResponseClient;
 import org.springframework.security.oauth2.client.endpoint.JwtBearerGrantRequest;
@@ -23,6 +27,7 @@ import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResp
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
+import org.springframework.security.oauth2.core.OAuth2Token;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.security.oauth2.jwt.Jwt;
 import org.springframework.util.Assert;
@@ -40,12 +45,18 @@ public final class JwtBearerOAuth2AuthorizedClientProvider implements OAuth2Auth
 
 	private OAuth2AccessTokenResponseClient<JwtBearerGrantRequest> accessTokenResponseClient = new DefaultJwtBearerTokenResponseClient();
 
+	private Duration clockSkew = Duration.ofSeconds(60);
+
+	private Clock clock = Clock.systemUTC();
+
 	/**
-	 * Attempt to authorize the {@link OAuth2AuthorizationContext#getClientRegistration()
-	 * client} in the provided {@code context}. Returns {@code null} if authorization is
-	 * not supported, e.g. the client's
-	 * {@link ClientRegistration#getAuthorizationGrantType() authorization grant type} is
-	 * not {@link AuthorizationGrantType#JWT_BEARER jwt-bearer}.
+	 * Attempt to authorize (or re-authorize) the
+	 * {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided
+	 * {@code context}. Returns {@code null} if authorization (or re-authorization) is not
+	 * supported, e.g. the client's {@link ClientRegistration#getAuthorizationGrantType()
+	 * authorization grant type} is not {@link AuthorizationGrantType#JWT_BEARER
+	 * jwt-bearer} OR the {@link OAuth2AuthorizedClient#getAccessToken() access token} is
+	 * not expired.
 	 * @param context the context that holds authorization-specific state for the client
 	 * @return the {@link OAuth2AuthorizedClient} or {@code null} if authorization is not
 	 * supported
@@ -59,8 +70,9 @@ public final class JwtBearerOAuth2AuthorizedClientProvider implements OAuth2Auth
 			return null;
 		}
 		OAuth2AuthorizedClient authorizedClient = context.getAuthorizedClient();
-		if (authorizedClient != null) {
-			// Client is already authorized
+		if (authorizedClient != null && !hasTokenExpired(authorizedClient.getAccessToken())) {
+			// If client is already authorized but access token is NOT expired than no
+			// need for re-authorization
 			return null;
 		}
 		if (!(context.getPrincipal().getPrincipal() instanceof Jwt)) {
@@ -95,6 +107,10 @@ public final class JwtBearerOAuth2AuthorizedClientProvider implements OAuth2Auth
 		}
 	}
 
+	private boolean hasTokenExpired(OAuth2Token token) {
+		return this.clock.instant().isAfter(token.getExpiresAt().minus(this.clockSkew));
+	}
+
 	/**
 	 * Sets the client used when requesting an access token credential at the Token
 	 * Endpoint for the {@code jwt-bearer} grant.
@@ -107,4 +123,31 @@ public final class JwtBearerOAuth2AuthorizedClientProvider implements OAuth2Auth
 		this.accessTokenResponseClient = accessTokenResponseClient;
 	}
 
+	/**
+	 * Sets the maximum acceptable clock skew, which is used when checking the
+	 * {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is
+	 * 60 seconds.
+	 *
+	 * <p>
+	 * An access token is considered expired if
+	 * {@code OAuth2AccessToken#getExpiresAt() - clockSkew} is before the current time
+	 * {@code clock#instant()}.
+	 * @param clockSkew the maximum acceptable clock skew
+	 */
+	public void setClockSkew(Duration clockSkew) {
+		Assert.notNull(clockSkew, "clockSkew cannot be null");
+		Assert.isTrue(clockSkew.getSeconds() >= 0, "clockSkew must be >= 0");
+		this.clockSkew = clockSkew;
+	}
+
+	/**
+	 * Sets the {@link Clock} used in {@link Instant#now(Clock)} when checking the access
+	 * token expiry.
+	 * @param clock the clock
+	 */
+	public void setClock(Clock clock) {
+		Assert.notNull(clock, "clock cannot be null");
+		this.clock = clock;
+	}
+
 }

+ 81 - 1
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/JwtBearerOAuth2AuthorizedClientProviderTests.java

@@ -16,6 +16,9 @@
 
 package org.springframework.security.oauth2.client;
 
+import java.time.Duration;
+import java.time.Instant;
+
 import org.junit.Before;
 import org.junit.Test;
 
@@ -27,6 +30,7 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses;
@@ -83,6 +87,33 @@ public class JwtBearerOAuth2AuthorizedClientProviderTests {
 				.withMessage("accessTokenResponseClient cannot be null");
 	}
 
+	@Test
+	public void setClockSkewWhenNullThenThrowIllegalArgumentException() {
+		// @formatter:off
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> this.authorizedClientProvider.setClockSkew(null))
+				.withMessage("clockSkew cannot be null");
+		// @formatter:on
+	}
+
+	@Test
+	public void setClockSkewWhenNegativeSecondsThenThrowIllegalArgumentException() {
+		// @formatter:off
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(-1)))
+				.withMessage("clockSkew must be >= 0");
+		// @formatter:on
+	}
+
+	@Test
+	public void setClockWhenNullThenThrowIllegalArgumentException() {
+		// @formatter:off
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> this.authorizedClientProvider.setClock(null))
+				.withMessage("clock cannot be null");
+		// @formatter:on
+	}
+
 	@Test
 	public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() {
 		// @formatter:off
@@ -105,7 +136,7 @@ public class JwtBearerOAuth2AuthorizedClientProviderTests {
 	}
 
 	@Test
-	public void authorizeWhenJwtBearerAndAuthorizedThenNotAuthorized() {
+	public void authorizeWhenJwtBearerAndTokenNotExpiredThenNotReauthorize() {
 		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration,
 				this.principal.getName(), TestOAuth2AccessTokens.scopes("read", "write"));
 		// @formatter:off
@@ -117,6 +148,55 @@ public class JwtBearerOAuth2AuthorizedClientProviderTests {
 		assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull();
 	}
 
+	@Test
+	public void authorizeWhenJwtBearerAndTokenExpiredThenReauthorize() {
+		Instant now = Instant.now();
+		Instant issuedAt = now.minus(Duration.ofMinutes(60));
+		Instant expiresAt = now.minus(Duration.ofMinutes(30));
+		OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token-1234",
+				issuedAt, expiresAt);
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration,
+				this.principal.getName(), accessToken);
+		OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
+		given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse);
+		// @formatter:off
+		OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
+				.withAuthorizedClient(authorizedClient)
+				.principal(this.principal)
+				.build();
+		// @formatter:on
+		authorizedClient = this.authorizedClientProvider.authorize(authorizationContext);
+		assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
+		assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName());
+		assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
+	}
+
+	@Test
+	public void authorizeWhenJwtBearerAndTokenNotExpiredButClockSkewForcesExpiryThenReauthorize() {
+		Instant now = Instant.now();
+		Instant issuedAt = now.minus(Duration.ofMinutes(60));
+		Instant expiresAt = now.plus(Duration.ofMinutes(1));
+		OAuth2AccessToken expiresInOneMinAccessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
+				"access-token-1234", issuedAt, expiresAt);
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration,
+				this.principal.getName(), expiresInOneMinAccessToken);
+		// Shorten the lifespan of the access token by 90 seconds, which will ultimately
+		// force it to expire on the client
+		this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(90));
+		OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
+		given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse);
+		// @formatter:off
+		OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
+				.withAuthorizedClient(authorizedClient)
+				.principal(this.principal)
+				.build();
+		// @formatter:on
+		OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext);
+		assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
+		assertThat(reauthorizedClient.getPrincipalName()).isEqualTo(this.principal.getName());
+		assertThat(reauthorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
+	}
+
 	@Test
 	public void authorizeWhenJwtBearerAndNotAuthorizedAndPrincipalNotJwtThenUnableToAuthorize() {
 		// @formatter:off