Bladeren bron

Add test for refresh_token grant with public client

Related gh-1432
Joe Grandja 1 jaar geleden
bovenliggende
commit
faad0be153

+ 160 - 1
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2RefreshTokenGrantTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020-2022 the original author or authors.
+ * Copyright 2020-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -23,6 +23,8 @@ import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
 
+import jakarta.servlet.http.HttpServletRequest;
+
 import com.nimbusds.jose.jwk.JWKSet;
 import com.nimbusds.jose.jwk.source.JWKSource;
 import com.nimbusds.jose.proc.SecurityContext;
@@ -34,6 +36,7 @@ import org.junit.jupiter.api.extension.ExtendWith;
 
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
 import org.springframework.context.annotation.Import;
 import org.springframework.http.HttpHeaders;
 import org.springframework.http.HttpStatus;
@@ -43,16 +46,25 @@ import org.springframework.jdbc.core.JdbcTemplate;
 import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
 import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
 import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
+import org.springframework.lang.Nullable;
 import org.springframework.mock.http.client.MockClientHttpResponse;
 import org.springframework.mock.web.MockHttpServletResponse;
+import org.springframework.security.authentication.AuthenticationProvider;
 import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.config.annotation.web.builders.HttpSecurity;
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
 import org.springframework.security.core.Authentication;
+import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.core.GrantedAuthority;
+import org.springframework.security.core.Transient;
 import org.springframework.security.crypto.password.NoOpPasswordEncoder;
 import org.springframework.security.crypto.password.PasswordEncoder;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.OAuth2Error;
+import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.OAuth2Token;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
@@ -66,6 +78,7 @@ import org.springframework.security.oauth2.server.authorization.OAuth2Authorizat
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
 import org.springframework.security.oauth2.server.authorization.OAuth2TokenType;
 import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations;
+import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;
 import org.springframework.security.oauth2.server.authorization.client.JdbcRegisteredClientRepository;
 import org.springframework.security.oauth2.server.authorization.client.JdbcRegisteredClientRepository.RegisteredClientParametersMapper;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
@@ -77,10 +90,15 @@ import org.springframework.security.oauth2.server.authorization.test.SpringTestC
 import org.springframework.security.oauth2.server.authorization.test.SpringTestContextExtension;
 import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
 import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer;
+import org.springframework.security.web.SecurityFilterChain;
+import org.springframework.security.web.authentication.AuthenticationConverter;
+import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.test.web.servlet.MockMvc;
 import org.springframework.test.web.servlet.MvcResult;
+import org.springframework.util.Assert;
 import org.springframework.util.LinkedMultiValueMap;
 import org.springframework.util.MultiValueMap;
+import org.springframework.util.StringUtils;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.hamcrest.CoreMatchers.containsString;
@@ -217,6 +235,32 @@ public class OAuth2RefreshTokenGrantTests {
 		assertThat(accessToken.isActive()).isTrue();
 	}
 
+	// gh-1430
+	@Test
+	public void requestWhenRefreshTokenRequestWithPublicClientThenReturnAccessTokenResponse() throws Exception {
+		this.spring.register(AuthorizationServerConfigurationWithPublicClientAuthentication.class).autowire();
+
+		RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient()
+				.authorizationGrantType(AuthorizationGrantType.REFRESH_TOKEN)
+				.build();
+		this.registeredClientRepository.save(registeredClient);
+
+		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
+		this.authorizationService.save(authorization);
+
+		this.mvc.perform(post(DEFAULT_TOKEN_ENDPOINT_URI)
+				.params(getRefreshTokenRequestParameters(authorization))
+				.param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()))
+				.andExpect(status().isOk())
+				.andExpect(header().string(HttpHeaders.CACHE_CONTROL, containsString("no-store")))
+				.andExpect(header().string(HttpHeaders.PRAGMA, containsString("no-cache")))
+				.andExpect(jsonPath("$.access_token").isNotEmpty())
+				.andExpect(jsonPath("$.token_type").isNotEmpty())
+				.andExpect(jsonPath("$.expires_in").isNotEmpty())
+				.andExpect(jsonPath("$.refresh_token").isNotEmpty())
+				.andExpect(jsonPath("$.scope").isNotEmpty());
+	}
+
 	private static MultiValueMap<String, String> getRefreshTokenRequestParameters(OAuth2Authorization authorization) {
 		MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
 		parameters.set(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.REFRESH_TOKEN.getValue());
@@ -307,4 +351,119 @@ public class OAuth2RefreshTokenGrantTests {
 		}
 
 	}
+
+	@EnableWebSecurity
+	@Configuration(proxyBeanMethods = false)
+	static class AuthorizationServerConfigurationWithPublicClientAuthentication extends AuthorizationServerConfiguration {
+		// @formatter:off
+		@Bean
+		SecurityFilterChain authorizationServerSecurityFilterChain(
+				HttpSecurity http, RegisteredClientRepository registeredClientRepository) throws Exception {
+
+			OAuth2AuthorizationServerConfigurer authorizationServerConfigurer =
+					new OAuth2AuthorizationServerConfigurer();
+			authorizationServerConfigurer
+					.clientAuthentication(clientAuthentication ->
+							clientAuthentication
+									.authenticationConverter(
+											new PublicClientRefreshTokenAuthenticationConverter())
+									.authenticationProvider(
+											new PublicClientRefreshTokenAuthenticationProvider(registeredClientRepository))
+					);
+			RequestMatcher endpointsMatcher = authorizationServerConfigurer.getEndpointsMatcher();
+
+			http
+					.securityMatcher(endpointsMatcher)
+					.authorizeHttpRequests(authorize ->
+							authorize.anyRequest().authenticated()
+					)
+					.csrf(csrf -> csrf.ignoringRequestMatchers(endpointsMatcher))
+					.apply(authorizationServerConfigurer);
+			return http.build();
+		}
+		// @formatter:on
+	}
+
+	@Transient
+	private static final class PublicClientRefreshTokenAuthenticationToken extends OAuth2ClientAuthenticationToken {
+
+		private PublicClientRefreshTokenAuthenticationToken(String clientId) {
+			super(clientId, ClientAuthenticationMethod.NONE, null, null);
+		}
+
+		private PublicClientRefreshTokenAuthenticationToken(RegisteredClient registeredClient) {
+			super(registeredClient, ClientAuthenticationMethod.NONE, null);
+		}
+
+	}
+
+	private static final class PublicClientRefreshTokenAuthenticationConverter implements AuthenticationConverter {
+
+		@Nullable
+		@Override
+		public Authentication convert(HttpServletRequest request) {
+			// grant_type (REQUIRED)
+			String grantType = request.getParameter(OAuth2ParameterNames.GRANT_TYPE);
+			if (!AuthorizationGrantType.REFRESH_TOKEN.getValue().equals(grantType)) {
+				return null;
+			}
+
+			// client_id (REQUIRED)
+			String clientId = request.getParameter(OAuth2ParameterNames.CLIENT_ID);
+			if (!StringUtils.hasText(clientId)) {
+				return null;
+			}
+
+			return new PublicClientRefreshTokenAuthenticationToken(clientId);
+		}
+
+	}
+
+	private static final class PublicClientRefreshTokenAuthenticationProvider implements AuthenticationProvider {
+		private final RegisteredClientRepository registeredClientRepository;
+
+		private PublicClientRefreshTokenAuthenticationProvider(RegisteredClientRepository registeredClientRepository) {
+			Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null");
+			this.registeredClientRepository = registeredClientRepository;
+		}
+
+		@Override
+		public Authentication authenticate(Authentication authentication) throws AuthenticationException {
+			PublicClientRefreshTokenAuthenticationToken publicClientAuthentication =
+					(PublicClientRefreshTokenAuthenticationToken) authentication;
+
+			if (!ClientAuthenticationMethod.NONE.equals(publicClientAuthentication.getClientAuthenticationMethod())) {
+				return null;
+			}
+
+			String clientId = publicClientAuthentication.getPrincipal().toString();
+			RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId);
+			if (registeredClient == null) {
+				throwInvalidClient(OAuth2ParameterNames.CLIENT_ID);
+			}
+
+			if (!registeredClient.getClientAuthenticationMethods().contains(
+					publicClientAuthentication.getClientAuthenticationMethod())) {
+				throwInvalidClient("authentication_method");
+			}
+
+			return new PublicClientRefreshTokenAuthenticationToken(registeredClient);
+		}
+
+		@Override
+		public boolean supports(Class<?> authentication) {
+			return PublicClientRefreshTokenAuthenticationToken.class.isAssignableFrom(authentication);
+		}
+
+		private static void throwInvalidClient(String parameterName) {
+			OAuth2Error error = new OAuth2Error(
+					OAuth2ErrorCodes.INVALID_CLIENT,
+					"Public client authentication failed: " + parameterName,
+					null
+			);
+			throw new OAuth2AuthenticationException(error);
+		}
+
+	}
+
 }