Просмотр исходного кода

Add OAuth2Client MockMvc Test Support

Fixes gh-7886
Josh Cummings 5 лет назад
Родитель
Сommit
c367378421

+ 117 - 6
test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java

@@ -1,5 +1,5 @@
 /*
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * you may not use this file except in compliance with the License.
@@ -432,6 +432,39 @@ public final class SecurityMockMvcRequestPostProcessors {
 		return new OidcLoginRequestPostProcessor(accessToken);
 		return new OidcLoginRequestPostProcessor(accessToken);
 	}
 	}
 
 
+	/**
+	 * Establish an {@link OAuth2AuthorizedClient} in the session. All details are
+	 * declarative and do not require associated tokens to be valid.
+	 *
+	 * <p>
+	 * The support works by associating the authorized client to the HttpServletRequest
+	 * via the {@link HttpSessionOAuth2AuthorizedClientRepository}
+	 * </p>
+	 *
+	 * @return the {@link OAuth2ClientRequestPostProcessor} for additional customization
+	 * @since 5.3
+	 */
+	public static OAuth2ClientRequestPostProcessor oauth2Client() {
+		return new OAuth2ClientRequestPostProcessor();
+	}
+
+	/**
+	 * Establish an {@link OAuth2AuthorizedClient} in the session. All details are
+	 * declarative and do not require associated tokens to be valid.
+	 *
+	 * <p>
+	 * The support works by associating the authorized client to the HttpServletRequest
+	 * via the {@link HttpSessionOAuth2AuthorizedClientRepository}
+	 * </p>
+	 *
+	 * @param registrationId The registration id for the {@link OAuth2AuthorizedClient}
+	 * @return the {@link OAuth2ClientRequestPostProcessor} for additional customization
+	 * @since 5.3
+	 */
+	public static OAuth2ClientRequestPostProcessor oauth2Client(String registrationId) {
+		return new OAuth2ClientRequestPostProcessor(registrationId);
+	}
+
 	/**
 	/**
 	 * Populates the X509Certificate instances onto the request
 	 * Populates the X509Certificate instances onto the request
 	 */
 	 */
@@ -1389,12 +1422,12 @@ public final class SecurityMockMvcRequestPostProcessors {
 			OAuth2User oauth2User = this.oauth2User.get();
 			OAuth2User oauth2User = this.oauth2User.get();
 			OAuth2AuthenticationToken token = new OAuth2AuthenticationToken
 			OAuth2AuthenticationToken token = new OAuth2AuthenticationToken
 					(oauth2User, oauth2User.getAuthorities(), this.clientRegistration.getRegistrationId());
 					(oauth2User, oauth2User.getAuthorities(), this.clientRegistration.getRegistrationId());
-			OAuth2AuthorizedClient client = new OAuth2AuthorizedClient
-					(this.clientRegistration, token.getName(), this.accessToken);
-			OAuth2AuthorizedClientRepository authorizedClientRepository = new HttpSessionOAuth2AuthorizedClientRepository();
-			authorizedClientRepository.saveAuthorizedClient(client, token, request, new MockHttpServletResponse());
 
 
-			return new AuthenticationRequestPostProcessor(token).postProcessRequest(request);
+			request = new AuthenticationRequestPostProcessor(token).postProcessRequest(request);
+			return new OAuth2ClientRequestPostProcessor()
+					.clientRegistration(this.clientRegistration)
+					.accessToken(this.accessToken)
+					.postProcessRequest(request);
 		}
 		}
 
 
 		private ClientRegistration.Builder clientRegistrationBuilder() {
 		private ClientRegistration.Builder clientRegistrationBuilder() {
@@ -1570,6 +1603,84 @@ public final class SecurityMockMvcRequestPostProcessors {
 		}
 		}
 	}
 	}
 
 
+	/**
+	 * @author Josh Cummings
+	 * @since 5.3
+	 */
+	public final static class OAuth2ClientRequestPostProcessor implements RequestPostProcessor {
+		private String registrationId = "test";
+		private ClientRegistration clientRegistration;
+		private OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
+				"access-token", null, null, Collections.singleton("user"));
+
+		private OAuth2ClientRequestPostProcessor() {
+		}
+
+		private OAuth2ClientRequestPostProcessor(String registrationId) {
+			this.registrationId = registrationId;
+			clientRegistration(c -> {});
+		}
+
+		/**
+		 * Use this {@link ClientRegistration}
+		 *
+		 * @param clientRegistration
+		 * @return the {@link OAuth2ClientRequestPostProcessor} for further configuration
+		 */
+		public OAuth2ClientRequestPostProcessor clientRegistration(ClientRegistration clientRegistration) {
+			this.clientRegistration = clientRegistration;
+			return this;
+		}
+
+		/**
+		 * Use this {@link Consumer} to configure a {@link ClientRegistration}
+		 *
+		 * @param clientRegistrationConfigurer the {@link ClientRegistration} configurer
+		 * @return the {@link OAuth2ClientRequestPostProcessor} for further configuration
+		 */
+		public OAuth2ClientRequestPostProcessor clientRegistration
+				(Consumer<ClientRegistration.Builder> clientRegistrationConfigurer) {
+
+			ClientRegistration.Builder builder = clientRegistrationBuilder();
+			clientRegistrationConfigurer.accept(builder);
+			this.clientRegistration = builder.build();
+			return this;
+		}
+
+		/**
+		 * Use this {@link OAuth2AccessToken}
+		 *
+		 * @param accessToken the {@link OAuth2AccessToken} to use
+		 * @return the {@link OAuth2ClientRequestPostProcessor} for further configuration
+		 */
+		public OAuth2ClientRequestPostProcessor accessToken(OAuth2AccessToken accessToken) {
+			this.accessToken = accessToken;
+			return this;
+		}
+
+		@Override
+		public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) {
+			if (this.clientRegistration == null) {
+				throw new IllegalArgumentException("Please specify a ClientRegistration via one " +
+						"of the clientRegistration methods");
+			}
+			OAuth2AuthorizedClient client = new OAuth2AuthorizedClient
+					(this.clientRegistration, "test-subject", this.accessToken);
+			OAuth2AuthorizedClientRepository authorizedClientRepository =
+					new HttpSessionOAuth2AuthorizedClientRepository();
+			authorizedClientRepository.saveAuthorizedClient(client, null, request, new MockHttpServletResponse());
+			return request;
+		}
+
+		private ClientRegistration.Builder clientRegistrationBuilder() {
+			return ClientRegistration.withRegistrationId(this.registrationId)
+					.authorizationGrantType(AuthorizationGrantType.PASSWORD)
+					.clientId("test-client")
+					.clientSecret("test-secret")
+					.tokenUri("https://idp.example.org/oauth/token");
+		}
+	}
+
 	private SecurityMockMvcRequestPostProcessors() {
 	private SecurityMockMvcRequestPostProcessors() {
 	}
 	}
 }
 }

+ 170 - 0
test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsOAuth2ClientTests.java

@@ -0,0 +1,170 @@
+/*
+ * Copyright 2002-2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.test.web.servlet.request;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.context.annotation.Bean;
+import org.springframework.mock.web.MockHttpServletRequest;
+import org.springframework.security.config.annotation.web.builders.HttpSecurity;
+import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
+import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
+import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
+import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
+import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizedClientRepository;
+import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.test.context.TestSecurityContextHolder;
+import org.springframework.test.context.ContextConfiguration;
+import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
+import org.springframework.test.context.web.WebAppConfiguration;
+import org.springframework.test.web.servlet.MockMvc;
+import org.springframework.test.web.servlet.setup.MockMvcBuilders;
+import org.springframework.web.bind.annotation.GetMapping;
+import org.springframework.web.bind.annotation.RestController;
+import org.springframework.web.context.WebApplicationContext;
+import org.springframework.web.servlet.config.annotation.EnableWebMvc;
+
+import static org.assertj.core.api.Assertions.assertThatCode;
+import static org.mockito.Mockito.mock;
+import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration;
+import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes;
+import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.oauth2Client;
+import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity;
+import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
+import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content;
+
+/**
+ * Tests for {@link SecurityMockMvcRequestPostProcessors#oidcLogin()}
+ *
+ * @author Josh Cummings
+ * @since 5.3
+ */
+@RunWith(SpringJUnit4ClassRunner.class)
+@ContextConfiguration
+@WebAppConfiguration
+public class SecurityMockMvcRequestPostProcessorsOAuth2ClientTests {
+	@Autowired
+	WebApplicationContext context;
+
+	MockMvc mvc;
+
+	@Before
+	public void setup() {
+		// @formatter:off
+		this.mvc = MockMvcBuilders
+			.webAppContextSetup(this.context)
+			.apply(springSecurity())
+			.build();
+		// @formatter:on
+	}
+
+	@After
+	public void cleanup() {
+		TestSecurityContextHolder.clearContext();
+	}
+
+
+	@Test
+	public void oauth2ClientWhenUsingDefaultsThenException()
+			throws Exception {
+
+		assertThatCode(() -> oauth2Client().postProcessRequest(new MockHttpServletRequest()))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessageContaining("ClientRegistration");
+	}
+
+	@Test
+	public void oauth2ClientWhenUsingDefaultsThenProducesDefaultAuthorizedClient()
+		throws Exception {
+
+		this.mvc.perform(get("/access-token").with(oauth2Client("registration-id")))
+				.andExpect(content().string("access-token"));
+		this.mvc.perform(get("/client-id").with(oauth2Client("registration-id")))
+				.andExpect(content().string("test-client"));
+	}
+
+	@Test
+	public void oauth2ClientWhenClientRegistrationThenUses()
+			throws Exception {
+
+		ClientRegistration clientRegistration = clientRegistration()
+				.registrationId("registration-id").clientId("client-id").build();
+		this.mvc.perform(get("/client-id")
+				.with(oauth2Client().clientRegistration(clientRegistration)))
+				.andExpect(content().string("client-id"));
+	}
+
+	@Test
+	public void oauth2ClientWhenClientRegistrationConsumerThenUses()
+			throws Exception {
+
+		this.mvc.perform(get("/client-id")
+				.with(oauth2Client("registration-id").clientRegistration(c -> c.clientId("client-id"))))
+				.andExpect(content().string("client-id"));
+	}
+
+	@Test
+	public void oauth2ClientWhenAccessTokenThenUses() throws Exception {
+		OAuth2AccessToken accessToken = noScopes();
+		this.mvc.perform(get("/access-token")
+				.with(oauth2Client("registration-id").accessToken(accessToken)))
+				.andExpect(content().string("no-scopes"));
+	}
+
+	@EnableWebSecurity
+	@EnableWebMvc
+	static class OAuth2ClientConfig extends WebSecurityConfigurerAdapter {
+		@Override
+		protected void configure(HttpSecurity http) throws Exception {
+			http
+				.authorizeRequests(authz -> authz
+					.anyRequest().permitAll()
+				)
+				.oauth2Client();
+		}
+
+		@Bean
+		ClientRegistrationRepository clientRegistrationRepository() {
+			return mock(ClientRegistrationRepository.class);
+		}
+
+
+		@Bean
+		OAuth2AuthorizedClientRepository authorizedClientRepository() {
+			return new HttpSessionOAuth2AuthorizedClientRepository();
+		}
+
+		@RestController
+		static class PrincipalController {
+			@GetMapping("/access-token")
+			String accessToken(@RegisteredOAuth2AuthorizedClient("registration-id") OAuth2AuthorizedClient authorizedClient) {
+				return authorizedClient.getAccessToken().getTokenValue();
+			}
+
+			@GetMapping("/client-id")
+			String clientId(@RegisteredOAuth2AuthorizedClient("registration-id") OAuth2AuthorizedClient authorizedClient) {
+				return authorizedClient.getClientRegistration().getClientId();
+			}
+		}
+	}
+}