Ver código fonte

Provide configuration for refresh token generator

Closes gh-377
Joe Grandja 4 anos atrás
pai
commit
3ea7d8c9b6

+ 26 - 2
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java

@@ -16,15 +16,21 @@
 package org.springframework.security.oauth2.server.authorization.authentication;
 package org.springframework.security.oauth2.server.authorization.authentication;
 
 
 import java.security.Principal;
 import java.security.Principal;
+import java.time.Duration;
+import java.time.Instant;
+import java.util.Base64;
 import java.util.Collections;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.Map;
 import java.util.Set;
 import java.util.Set;
+import java.util.function.Supplier;
 
 
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.security.authentication.AuthenticationProvider;
 import org.springframework.security.authentication.AuthenticationProvider;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.core.AuthenticationException;
+import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
+import org.springframework.security.crypto.keygen.StringKeyGenerator;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
@@ -74,9 +80,12 @@ public final class OAuth2AuthorizationCodeAuthenticationProvider implements Auth
 			new OAuth2TokenType(OAuth2ParameterNames.CODE);
 			new OAuth2TokenType(OAuth2ParameterNames.CODE);
 	private static final OAuth2TokenType ID_TOKEN_TOKEN_TYPE =
 	private static final OAuth2TokenType ID_TOKEN_TOKEN_TYPE =
 			new OAuth2TokenType(OidcParameterNames.ID_TOKEN);
 			new OAuth2TokenType(OidcParameterNames.ID_TOKEN);
+	private static final StringKeyGenerator DEFAULT_REFRESH_TOKEN_GENERATOR =
+			new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96);
 	private final OAuth2AuthorizationService authorizationService;
 	private final OAuth2AuthorizationService authorizationService;
 	private final JwtEncoder jwtEncoder;
 	private final JwtEncoder jwtEncoder;
 	private OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer = (context) -> {};
 	private OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer = (context) -> {};
+	private Supplier<String> refreshTokenGenerator = DEFAULT_REFRESH_TOKEN_GENERATOR::generateKey;
 	private ProviderSettings providerSettings;
 	private ProviderSettings providerSettings;
 
 
 	/**
 	/**
@@ -97,6 +106,16 @@ public final class OAuth2AuthorizationCodeAuthenticationProvider implements Auth
 		this.jwtCustomizer = jwtCustomizer;
 		this.jwtCustomizer = jwtCustomizer;
 	}
 	}
 
 
+	/**
+	 * Sets the {@code Supplier<String>} that generates the value for the {@link OAuth2RefreshToken}.
+	 *
+	 * @param refreshTokenGenerator the {@code Supplier<String>} that generates the value for the {@link OAuth2RefreshToken}
+	 */
+	public void setRefreshTokenGenerator(Supplier<String> refreshTokenGenerator) {
+		Assert.notNull(refreshTokenGenerator, "refreshTokenGenerator cannot be null");
+		this.refreshTokenGenerator = refreshTokenGenerator;
+	}
+
 	@Autowired(required = false)
 	@Autowired(required = false)
 	protected void setProviderSettings(ProviderSettings providerSettings) {
 	protected void setProviderSettings(ProviderSettings providerSettings) {
 		this.providerSettings = providerSettings;
 		this.providerSettings = providerSettings;
@@ -173,8 +192,7 @@ public final class OAuth2AuthorizationCodeAuthenticationProvider implements Auth
 
 
 		OAuth2RefreshToken refreshToken = null;
 		OAuth2RefreshToken refreshToken = null;
 		if (registeredClient.getAuthorizationGrantTypes().contains(AuthorizationGrantType.REFRESH_TOKEN)) {
 		if (registeredClient.getAuthorizationGrantTypes().contains(AuthorizationGrantType.REFRESH_TOKEN)) {
-			refreshToken = OAuth2RefreshTokenAuthenticationProvider.generateRefreshToken(
-					registeredClient.getTokenSettings().getRefreshTokenTimeToLive());
+			refreshToken = generateRefreshToken(registeredClient.getTokenSettings().getRefreshTokenTimeToLive());
 		}
 		}
 
 
 		Jwt jwtIdToken = null;
 		Jwt jwtIdToken = null;
@@ -250,4 +268,10 @@ public final class OAuth2AuthorizationCodeAuthenticationProvider implements Auth
 		return OAuth2AuthorizationCodeAuthenticationToken.class.isAssignableFrom(authentication);
 		return OAuth2AuthorizationCodeAuthenticationToken.class.isAssignableFrom(authentication);
 	}
 	}
 
 
+	private OAuth2RefreshToken generateRefreshToken(Duration tokenTimeToLive) {
+		Instant issuedAt = Instant.now();
+		Instant expiresAt = issuedAt.plus(tokenTimeToLive);
+		return new OAuth2RefreshToken(this.refreshTokenGenerator.get(), issuedAt, expiresAt);
+	}
+
 }
 }

+ 16 - 3
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java

@@ -23,6 +23,7 @@ import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.Map;
 import java.util.Set;
 import java.util.Set;
+import java.util.function.Supplier;
 
 
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.security.authentication.AuthenticationProvider;
 import org.springframework.security.authentication.AuthenticationProvider;
@@ -73,10 +74,12 @@ import static org.springframework.security.oauth2.server.authorization.authentic
  */
  */
 public final class OAuth2RefreshTokenAuthenticationProvider implements AuthenticationProvider {
 public final class OAuth2RefreshTokenAuthenticationProvider implements AuthenticationProvider {
 	private static final OAuth2TokenType ID_TOKEN_TOKEN_TYPE = new OAuth2TokenType(OidcParameterNames.ID_TOKEN);
 	private static final OAuth2TokenType ID_TOKEN_TOKEN_TYPE = new OAuth2TokenType(OidcParameterNames.ID_TOKEN);
-	private static final StringKeyGenerator TOKEN_GENERATOR = new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96);
+	private static final StringKeyGenerator DEFAULT_REFRESH_TOKEN_GENERATOR =
+			new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96);
 	private final OAuth2AuthorizationService authorizationService;
 	private final OAuth2AuthorizationService authorizationService;
 	private final JwtEncoder jwtEncoder;
 	private final JwtEncoder jwtEncoder;
 	private OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer = (context) -> {};
 	private OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer = (context) -> {};
+	private Supplier<String> refreshTokenGenerator = DEFAULT_REFRESH_TOKEN_GENERATOR::generateKey;
 	private ProviderSettings providerSettings;
 	private ProviderSettings providerSettings;
 
 
 	/**
 	/**
@@ -98,6 +101,16 @@ public final class OAuth2RefreshTokenAuthenticationProvider implements Authentic
 		this.jwtCustomizer = jwtCustomizer;
 		this.jwtCustomizer = jwtCustomizer;
 	}
 	}
 
 
+	/**
+	 * Sets the {@code Supplier<String>} that generates the value for the {@link OAuth2RefreshToken}.
+	 *
+	 * @param refreshTokenGenerator the {@code Supplier<String>} that generates the value for the {@link OAuth2RefreshToken}
+	 */
+	public void setRefreshTokenGenerator(Supplier<String> refreshTokenGenerator) {
+		Assert.notNull(refreshTokenGenerator, "refreshTokenGenerator cannot be null");
+		this.refreshTokenGenerator = refreshTokenGenerator;
+	}
+
 	@Autowired(required = false)
 	@Autowired(required = false)
 	protected void setProviderSettings(ProviderSettings providerSettings) {
 	protected void setProviderSettings(ProviderSettings providerSettings) {
 		this.providerSettings = providerSettings;
 		this.providerSettings = providerSettings;
@@ -246,9 +259,9 @@ public final class OAuth2RefreshTokenAuthenticationProvider implements Authentic
 		return OAuth2RefreshTokenAuthenticationToken.class.isAssignableFrom(authentication);
 		return OAuth2RefreshTokenAuthenticationToken.class.isAssignableFrom(authentication);
 	}
 	}
 
 
-	static OAuth2RefreshToken generateRefreshToken(Duration tokenTimeToLive) {
+	private OAuth2RefreshToken generateRefreshToken(Duration tokenTimeToLive) {
 		Instant issuedAt = Instant.now();
 		Instant issuedAt = Instant.now();
 		Instant expiresAt = issuedAt.plus(tokenTimeToLive);
 		Instant expiresAt = issuedAt.plus(tokenTimeToLive);
-		return new OAuth2RefreshToken(TOKEN_GENERATOR.generateKey(), issuedAt, expiresAt);
+		return new OAuth2RefreshToken(this.refreshTokenGenerator.get(), issuedAt, expiresAt);
 	}
 	}
 }
 }

+ 40 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java

@@ -23,6 +23,7 @@ import java.util.HashMap;
 import java.util.HashSet;
 import java.util.HashSet;
 import java.util.Map;
 import java.util.Map;
 import java.util.Set;
 import java.util.Set;
+import java.util.function.Supplier;
 
 
 import org.junit.Before;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.Test;
@@ -61,6 +62,7 @@ import static org.assertj.core.api.Assertions.entry;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 import static org.mockito.Mockito.when;
@@ -110,6 +112,13 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
 				.hasMessage("jwtCustomizer cannot be null");
 				.hasMessage("jwtCustomizer cannot be null");
 	}
 	}
 
 
+	@Test
+	public void setRefreshTokenGeneratorWhenNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.authenticationProvider.setRefreshTokenGenerator(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("refreshTokenGenerator cannot be null");
+	}
+
 	@Test
 	@Test
 	public void supportsWhenTypeOAuth2AuthorizationCodeAuthenticationTokenThenReturnTrue() {
 	public void supportsWhenTypeOAuth2AuthorizationCodeAuthenticationTokenThenReturnTrue() {
 		assertThat(this.authenticationProvider.supports(OAuth2AuthorizationCodeAuthenticationToken.class)).isTrue();
 		assertThat(this.authenticationProvider.supports(OAuth2AuthorizationCodeAuthenticationToken.class)).isTrue();
@@ -440,6 +449,37 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
 		assertThat(accessTokenAuthentication.getRefreshToken()).isNull();
 		assertThat(accessTokenAuthentication.getRefreshToken()).isNull();
 	}
 	}
 
 
+	@Test
+	public void authenticateWhenCustomRefreshTokenGeneratorThenUsed() {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
+		when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE)))
+				.thenReturn(authorization);
+
+		when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt());
+
+		@SuppressWarnings("unchecked")
+		Supplier<String> refreshTokenGenerator = spy(new Supplier<String>() {
+			@Override
+			public String get() {
+				return "custom-refresh-token";
+			}
+		});
+		this.authenticationProvider.setRefreshTokenGenerator(refreshTokenGenerator);
+
+		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
+		OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
+				OAuth2AuthorizationRequest.class.getName());
+		OAuth2AuthorizationCodeAuthenticationToken authentication =
+				new OAuth2AuthorizationCodeAuthenticationToken(AUTHORIZATION_CODE, clientPrincipal, authorizationRequest.getRedirectUri(), null);
+
+		OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
+				(OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication);
+
+		verify(refreshTokenGenerator).get();
+		assertThat(accessTokenAuthentication.getRefreshToken().getTokenValue()).isEqualTo(refreshTokenGenerator.get());
+	}
+
 	private static Jwt createJwt() {
 	private static Jwt createJwt() {
 		Instant issuedAt = Instant.now();
 		Instant issuedAt = Instant.now();
 		Instant expiresAt = issuedAt.plus(1, ChronoUnit.HOURS);
 		Instant expiresAt = issuedAt.plus(1, ChronoUnit.HOURS);

+ 40 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java

@@ -23,6 +23,7 @@ import java.util.HashMap;
 import java.util.HashSet;
 import java.util.HashSet;
 import java.util.Map;
 import java.util.Map;
 import java.util.Set;
 import java.util.Set;
+import java.util.function.Supplier;
 
 
 import org.junit.Before;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.Test;
@@ -59,6 +60,7 @@ import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 import static org.mockito.Mockito.when;
@@ -111,6 +113,13 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
 				.hasMessage("jwtCustomizer cannot be null");
 				.hasMessage("jwtCustomizer cannot be null");
 	}
 	}
 
 
+	@Test
+	public void setRefreshTokenGeneratorWhenNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.authenticationProvider.setRefreshTokenGenerator(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("refreshTokenGenerator cannot be null");
+	}
+
 	@Test
 	@Test
 	public void supportsWhenSupportedAuthenticationThenTrue() {
 	public void supportsWhenSupportedAuthenticationThenTrue() {
 		assertThat(this.authenticationProvider.supports(OAuth2RefreshTokenAuthenticationToken.class)).isTrue();
 		assertThat(this.authenticationProvider.supports(OAuth2RefreshTokenAuthenticationToken.class)).isTrue();
@@ -281,6 +290,37 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
 		assertThat(accessTokenAuthentication.getAccessToken().getScopes()).isEqualTo(requestedScopes);
 		assertThat(accessTokenAuthentication.getAccessToken().getScopes()).isEqualTo(requestedScopes);
 	}
 	}
 
 
+	@Test
+	public void authenticateWhenCustomRefreshTokenGeneratorThenUsed() {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
+				.tokenSettings(TokenSettings.builder().reuseRefreshTokens(false).build())
+				.build();
+		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
+		when(this.authorizationService.findByToken(
+				eq(authorization.getRefreshToken().getToken().getTokenValue()),
+				eq(OAuth2TokenType.REFRESH_TOKEN)))
+				.thenReturn(authorization);
+
+		@SuppressWarnings("unchecked")
+		Supplier<String> refreshTokenGenerator = spy(new Supplier<String>() {
+			@Override
+			public String get() {
+				return "custom-refresh-token";
+			}
+		});
+		this.authenticationProvider.setRefreshTokenGenerator(refreshTokenGenerator);
+
+		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
+		OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
+				authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, null, null);
+
+		OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
+				(OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication);
+
+		verify(refreshTokenGenerator).get();
+		assertThat(accessTokenAuthentication.getRefreshToken().getTokenValue()).isEqualTo(refreshTokenGenerator.get());
+	}
+
 	@Test
 	@Test
 	public void authenticateWhenRequestedScopesNotAuthorizedThenThrowOAuth2AuthenticationException() {
 	public void authenticateWhenRequestedScopesNotAuthorizedThenThrowOAuth2AuthenticationException() {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();