Browse Source

Allow configurable refresh token strategy for authorization_code grant

Closes gh-1430
Joe Grandja 1 year ago
parent
commit
71d923575a

+ 13 - 15
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java

@@ -38,7 +38,6 @@ import org.springframework.security.core.session.SessionInformation;
 import org.springframework.security.core.session.SessionRegistry;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.ClaimAccessor;
-import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
@@ -213,24 +212,23 @@ public final class OAuth2AuthorizationCodeAuthenticationProvider implements Auth
 
 		// ----- Refresh token -----
 		OAuth2RefreshToken refreshToken = null;
-		if (registeredClient.getAuthorizationGrantTypes().contains(AuthorizationGrantType.REFRESH_TOKEN) &&
-				// Do not issue refresh token to public client
-				!clientPrincipal.getClientAuthenticationMethod().equals(ClientAuthenticationMethod.NONE)) {
-
+		if (registeredClient.getAuthorizationGrantTypes().contains(AuthorizationGrantType.REFRESH_TOKEN)) {
 			tokenContext = tokenContextBuilder.tokenType(OAuth2TokenType.REFRESH_TOKEN).build();
 			OAuth2Token generatedRefreshToken = this.tokenGenerator.generate(tokenContext);
-			if (!(generatedRefreshToken instanceof OAuth2RefreshToken)) {
-				OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR,
-						"The token generator failed to generate the refresh token.", ERROR_URI);
-				throw new OAuth2AuthenticationException(error);
-			}
+			if (generatedRefreshToken != null) {
+				if (!(generatedRefreshToken instanceof OAuth2RefreshToken)) {
+					OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR,
+							"The token generator failed to generate a valid refresh token.", ERROR_URI);
+					throw new OAuth2AuthenticationException(error);
+				}
 
-			if (this.logger.isTraceEnabled()) {
-				this.logger.trace("Generated refresh token");
-			}
+				if (this.logger.isTraceEnabled()) {
+					this.logger.trace("Generated refresh token");
+				}
 
-			refreshToken = (OAuth2RefreshToken) generatedRefreshToken;
-			authorizationBuilder.refreshToken(refreshToken);
+				refreshToken = (OAuth2RefreshToken) generatedRefreshToken;
+				authorizationBuilder.refreshToken(refreshToken);
+			}
 		}
 
 		// ----- ID token -----

+ 17 - 1
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2RefreshTokenGenerator.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020-2022 the original author or authors.
+ * Copyright 2020-2023 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -21,8 +21,11 @@ import java.util.Base64;
 import org.springframework.lang.Nullable;
 import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
 import org.springframework.security.crypto.keygen.StringKeyGenerator;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 import org.springframework.security.oauth2.server.authorization.OAuth2TokenType;
+import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;
 
 /**
  * An {@link OAuth2TokenGenerator} that generates an {@link OAuth2RefreshToken}.
@@ -42,9 +45,22 @@ public final class OAuth2RefreshTokenGenerator implements OAuth2TokenGenerator<O
 		if (!OAuth2TokenType.REFRESH_TOKEN.equals(context.getTokenType())) {
 			return null;
 		}
+		if (isPublicClientForAuthorizationCodeGrant(context)) {
+			// Do not issue refresh token to public client
+			return null;
+		}
+
 		Instant issuedAt = Instant.now();
 		Instant expiresAt = issuedAt.plus(context.getRegisteredClient().getTokenSettings().getRefreshTokenTimeToLive());
 		return new OAuth2RefreshToken(this.refreshTokenGenerator.generateKey(), issuedAt, expiresAt);
 	}
 
+	private static boolean isPublicClientForAuthorizationCodeGrant(OAuth2TokenContext context) {
+		if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(context.getAuthorizationGrantType()) &&
+				(context.getAuthorizationGrant().getPrincipal() instanceof OAuth2ClientAuthenticationToken clientPrincipal)) {
+			return clientPrincipal.getClientAuthenticationMethod().equals(ClientAuthenticationMethod.NONE);
+		}
+		return false;
+	}
+
 }

+ 5 - 3
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java

@@ -42,6 +42,7 @@ import org.springframework.security.core.session.SessionInformation;
 import org.springframework.security.core.session.SessionRegistry;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.OAuth2Token;
@@ -371,7 +372,7 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
 	}
 
 	@Test
-	public void authenticateWhenRefreshTokenNotGeneratedThenThrowOAuth2AuthenticationException() {
+	public void authenticateWhenInvalidRefreshTokenGeneratedThenThrowOAuth2AuthenticationException() {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
 		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
 		when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE)))
@@ -389,7 +390,8 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
 		doAnswer(answer -> {
 			OAuth2TokenContext context = answer.getArgument(0);
 			if (OAuth2TokenType.REFRESH_TOKEN.equals(context.getTokenType())) {
-				return null;
+				return new OAuth2AccessToken(
+						OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now(), Instant.now().plusSeconds(300));
 			} else {
 				return answer.callRealMethod();
 			}
@@ -400,7 +402,7 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
 				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
 				.satisfies(error -> {
 					assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.SERVER_ERROR);
-					assertThat(error.getDescription()).contains("The token generator failed to generate the refresh token.");
+					assertThat(error.getDescription()).contains("The token generator failed to generate a valid refresh token.");
 				});
 	}
 

+ 88 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationCodeGrantTests.java

@@ -55,6 +55,7 @@ import org.springframework.jdbc.core.JdbcTemplate;
 import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
 import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
 import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
+import org.springframework.lang.Nullable;
 import org.springframework.mock.http.client.MockClientHttpResponse;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.security.authentication.AuthenticationProvider;
@@ -65,9 +66,12 @@ import org.springframework.security.config.annotation.web.configuration.EnableWe
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.GrantedAuthority;
 import org.springframework.security.core.authority.SimpleGrantedAuthority;
+import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
+import org.springframework.security.crypto.keygen.StringKeyGenerator;
 import org.springframework.security.crypto.password.NoOpPasswordEncoder;
 import org.springframework.security.crypto.password.PasswordEncoder;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 import org.springframework.security.oauth2.core.OAuth2Token;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
@@ -426,6 +430,54 @@ public class OAuth2AuthorizationCodeGrantTests {
 		assertThat(authorizationCodeToken.getMetadata().get(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME)).isEqualTo(true);
 	}
 
+	// gh-1430
+	@Test
+	public void requestWhenPublicClientWithPkceAndCustomRefreshTokenGeneratorThenReturnRefreshToken() throws Exception {
+		this.spring.register(AuthorizationServerConfigurationWithCustomRefreshTokenGenerator.class).autowire();
+
+		RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient()
+				.authorizationGrantType(AuthorizationGrantType.REFRESH_TOKEN)
+				.build();
+		this.registeredClientRepository.save(registeredClient);
+
+		MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
+				.params(getAuthorizationRequestParameters(registeredClient))
+				.param(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE)
+				.param(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256")
+				.with(user("user")))
+				.andExpect(status().is3xxRedirection())
+				.andReturn();
+		String redirectedUrl = mvcResult.getResponse().getRedirectedUrl();
+		assertThat(redirectedUrl).matches("https://example.com\\?code=.{15,}&state=" + STATE_URL_ENCODED);
+
+		String authorizationCode = extractParameterFromRedirectUri(redirectedUrl, "code");
+		OAuth2Authorization authorizationCodeAuthorization = this.authorizationService.findByToken(authorizationCode, AUTHORIZATION_CODE_TOKEN_TYPE);
+		assertThat(authorizationCodeAuthorization).isNotNull();
+		assertThat(authorizationCodeAuthorization.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
+
+		this.mvc.perform(post(DEFAULT_TOKEN_ENDPOINT_URI)
+				.params(getTokenRequestParameters(registeredClient, authorizationCodeAuthorization))
+				.param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId())
+				.param(PkceParameterNames.CODE_VERIFIER, S256_CODE_VERIFIER))
+				.andExpect(header().string(HttpHeaders.CACHE_CONTROL, containsString("no-store")))
+				.andExpect(header().string(HttpHeaders.PRAGMA, containsString("no-cache")))
+				.andExpect(status().isOk())
+				.andExpect(jsonPath("$.access_token").isNotEmpty())
+				.andExpect(jsonPath("$.token_type").isNotEmpty())
+				.andExpect(jsonPath("$.expires_in").isNotEmpty())
+				.andExpect(jsonPath("$.refresh_token").isNotEmpty())
+				.andExpect(jsonPath("$.scope").isNotEmpty());
+
+		OAuth2Authorization authorization = this.authorizationService.findById(authorizationCodeAuthorization.getId());
+		assertThat(authorization).isNotNull();
+		assertThat(authorization.getAccessToken()).isNotNull();
+		assertThat(authorization.getRefreshToken()).isNotNull();
+
+		OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCodeToken = authorization.getToken(OAuth2AuthorizationCode.class);
+		assertThat(authorizationCodeToken).isNotNull();
+		assertThat(authorizationCodeToken.getMetadata().get(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME)).isEqualTo(true);
+	}
+
 	@Test
 	public void requestWhenConfidentialClientWithPkceAndMissingCodeVerifierThenBadRequest() throws Exception {
 		this.spring.register(AuthorizationServerConfiguration.class).autowire();
@@ -896,6 +948,42 @@ public class OAuth2AuthorizationCodeGrantTests {
 
 	}
 
+	@EnableWebSecurity
+	@Import(OAuth2AuthorizationServerConfiguration.class)
+	static class AuthorizationServerConfigurationWithCustomRefreshTokenGenerator extends AuthorizationServerConfiguration {
+
+		@Bean
+		JwtEncoder jwtEncoder() {
+			return jwtEncoder;
+		}
+
+		@Bean
+		OAuth2TokenGenerator<?> tokenGenerator() {
+			JwtGenerator jwtGenerator = new JwtGenerator(jwtEncoder());
+			jwtGenerator.setJwtCustomizer(jwtCustomizer());
+			OAuth2TokenGenerator<OAuth2RefreshToken> refreshTokenGenerator = new CustomRefreshTokenGenerator();
+			return new DelegatingOAuth2TokenGenerator(jwtGenerator, refreshTokenGenerator);
+		}
+
+		private static final class CustomRefreshTokenGenerator implements OAuth2TokenGenerator<OAuth2RefreshToken> {
+			private final StringKeyGenerator refreshTokenGenerator =
+					new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96);
+
+			@Nullable
+			@Override
+			public OAuth2RefreshToken generate(OAuth2TokenContext context) {
+				if (!OAuth2TokenType.REFRESH_TOKEN.equals(context.getTokenType())) {
+					return null;
+				}
+				Instant issuedAt = Instant.now();
+				Instant expiresAt = issuedAt.plus(context.getRegisteredClient().getTokenSettings().getRefreshTokenTimeToLive());
+				return new OAuth2RefreshToken(this.refreshTokenGenerator.generateKey(), issuedAt, expiresAt);
+			}
+
+		}
+
+	}
+
 	@EnableWebSecurity
 	@Configuration(proxyBeanMethods = false)
 	static class AuthorizationServerConfigurationWithSecurityContextRepository extends AuthorizationServerConfiguration {