Browse Source

Add oauth2Client WebTestClient Support

Fixes gh-7910
Josh Cummings 5 years ago
parent
commit
ffb5a3a0d4

+ 159 - 26
test/src/main/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurers.java

@@ -60,6 +60,7 @@ import org.springframework.security.oauth2.core.user.OAuth2UserAuthority;
 import org.springframework.security.oauth2.jwt.Jwt;
 import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken;
 import org.springframework.security.oauth2.server.resource.authentication.JwtGrantedAuthoritiesConverter;
+import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors;
 import org.springframework.security.web.server.csrf.CsrfWebFilter;
 import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
 import org.springframework.test.web.reactive.server.MockServerConfigurer;
@@ -182,6 +183,39 @@ public class SecurityMockServerConfigurers {
 		return new OidcLoginMutator(accessToken);
 	}
 
+	/**
+	 * Updates the ServerWebExchange to establish a {@link OAuth2AuthorizedClient} in the session.
+	 * All details are declarative and do not require the corresponding OAuth 2.0 tokens to be valid.
+	 *
+	 * <p>
+	 * 	The support works by associating the authorized client to the ServerWebExchange
+	 * 	via the {@link WebSessionServerOAuth2AuthorizedClientRepository}
+	 * </p>
+	 *
+	 * @return the {@link OAuth2ClientMutator} to further configure or use
+	 * @since 5.3
+	 */
+	public static OAuth2ClientMutator mockOAuth2Client() {
+		return new OAuth2ClientMutator();
+	}
+
+	/**
+	 * Updates the ServerWebExchange to establish a {@link OAuth2AuthorizedClient} in the session.
+	 * All details are declarative and do not require the corresponding OAuth 2.0 tokens to be valid.
+	 *
+	 * <p>
+	 * 	The support works by associating the authorized client to the ServerWebExchange
+	 * 	via the {@link WebSessionServerOAuth2AuthorizedClientRepository}
+	 * </p>
+	 *
+	 * @param registrationId The registration id associated with the {@link OAuth2AuthorizedClient}
+	 * @return the {@link OAuth2ClientMutator} to further configure or use
+	 * @since 5.3
+	 */
+	public static OAuth2ClientMutator mockOAuth2Client(String registrationId) {
+		return new OAuth2ClientMutator(registrationId);
+	}
+
 	public static CsrfMutator csrf() {
 		return new CsrfMutator();
 	}
@@ -591,12 +625,19 @@ public class SecurityMockServerConfigurers {
 		@Override
 		public void beforeServerCreated(WebHttpHandlerBuilder builder) {
 			OAuth2AuthenticationToken token = getToken();
-			builder.filters(addAuthorizedClientFilter(token));
+			mockOAuth2Client()
+					.accessToken(this.accessToken)
+					.clientRegistration(this.clientRegistration)
+					.beforeServerCreated(builder);
 			mockAuthentication(getToken()).beforeServerCreated(builder);
 		}
 
 		@Override
 		public void afterConfigureAdded(WebTestClient.MockServerSpec<?> serverSpec) {
+			mockOAuth2Client()
+					.accessToken(this.accessToken)
+					.clientRegistration(this.clientRegistration)
+					.afterConfigureAdded(serverSpec);
 			mockAuthentication(getToken()).afterConfigureAdded(serverSpec);
 		}
 
@@ -606,26 +647,18 @@ public class SecurityMockServerConfigurers {
 				@Nullable WebHttpHandlerBuilder httpHandlerBuilder,
 				@Nullable ClientHttpConnector connector) {
 			OAuth2AuthenticationToken token = getToken();
-			httpHandlerBuilder.filters(addAuthorizedClientFilter(token));
+			mockOAuth2Client()
+					.accessToken(this.accessToken)
+					.clientRegistration(this.clientRegistration)
+					.afterConfigurerAdded(builder, httpHandlerBuilder, connector);
 			mockAuthentication(token).afterConfigurerAdded(builder, httpHandlerBuilder, connector);
 		}
 
-		private Consumer<List<WebFilter>> addAuthorizedClientFilter(OAuth2AuthenticationToken token) {
-			OAuth2AuthorizedClient client = getClient();
-			return filters -> filters.add(0, (exchange, chain) ->
-					this.authorizedClientRepository.saveAuthorizedClient(client, token, exchange)
-							.then(chain.filter(exchange)));
-		}
-
 		private OAuth2AuthenticationToken getToken() {
 			OAuth2User oauth2User = this.oauth2User.get();
 			return new OAuth2AuthenticationToken(oauth2User, oauth2User.getAuthorities(), this.clientRegistration.getRegistrationId());
 		}
 
-		private OAuth2AuthorizedClient getClient() {
-			return new OAuth2AuthorizedClient(this.clientRegistration, getToken().getName(), this.accessToken);
-		}
-
 		private ClientRegistration.Builder clientRegistrationBuilder() {
 			return ClientRegistration.withRegistrationId("test")
 					.authorizationGrantType(AuthorizationGrantType.PASSWORD)
@@ -760,12 +793,19 @@ public class SecurityMockServerConfigurers {
 		@Override
 		public void beforeServerCreated(WebHttpHandlerBuilder builder) {
 			OAuth2AuthenticationToken token = getToken();
-			builder.filters(addAuthorizedClientFilter(token));
+			mockOAuth2Client()
+					.accessToken(this.accessToken)
+					.clientRegistration(this.clientRegistration)
+					.beforeServerCreated(builder);
 			mockAuthentication(getToken()).beforeServerCreated(builder);
 		}
 
 		@Override
 		public void afterConfigureAdded(WebTestClient.MockServerSpec<?> serverSpec) {
+			mockOAuth2Client()
+					.accessToken(this.accessToken)
+					.clientRegistration(this.clientRegistration)
+					.afterConfigureAdded(serverSpec);
 			mockAuthentication(getToken()).afterConfigureAdded(serverSpec);
 		}
 
@@ -775,17 +815,13 @@ public class SecurityMockServerConfigurers {
 				@Nullable WebHttpHandlerBuilder httpHandlerBuilder,
 				@Nullable ClientHttpConnector connector) {
 			OAuth2AuthenticationToken token = getToken();
-			httpHandlerBuilder.filters(addAuthorizedClientFilter(token));
+			mockOAuth2Client()
+					.accessToken(this.accessToken)
+					.clientRegistration(this.clientRegistration)
+					.afterConfigurerAdded(builder, httpHandlerBuilder, connector);
 			mockAuthentication(token).afterConfigurerAdded(builder, httpHandlerBuilder, connector);
 		}
 
-		private Consumer<List<WebFilter>> addAuthorizedClientFilter(OAuth2AuthenticationToken token) {
-			OAuth2AuthorizedClient client = getClient();
-			return filters -> filters.add(0, (exchange, chain) ->
-					authorizedClientRepository.saveAuthorizedClient(client, token, exchange)
-							.then(chain.filter(exchange)));
-		}
-
 		private ClientRegistration.Builder clientRegistrationBuilder() {
 			return ClientRegistration.withRegistrationId("test")
 					.authorizationGrantType(AuthorizationGrantType.PASSWORD)
@@ -798,10 +834,6 @@ public class SecurityMockServerConfigurers {
 			return new OAuth2AuthenticationToken(oidcUser, oidcUser.getAuthorities(), this.clientRegistration.getRegistrationId());
 		}
 
-		private OAuth2AuthorizedClient getClient() {
-			return new OAuth2AuthorizedClient(this.clientRegistration, getToken().getName(), this.accessToken);
-		}
-
 		private Collection<GrantedAuthority> getAuthorities() {
 			if (this.authorities == null) {
 				Set<GrantedAuthority> authorities = new LinkedHashSet<>();
@@ -831,4 +863,105 @@ public class SecurityMockServerConfigurers {
 			return new DefaultOidcUser(getAuthorities(), getOidcIdToken(), this.userInfo);
 		}
 	}
+
+	/**
+	 * @author Josh Cummings
+	 * @since 5.3
+	 */
+	public final static class OAuth2ClientMutator implements WebTestClientConfigurer, MockServerConfigurer {
+		private String registrationId = "test";
+		private ClientRegistration clientRegistration;
+		private OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
+				"access-token", null, null, Collections.singleton("user"));
+
+		private ServerOAuth2AuthorizedClientRepository authorizedClientRepository =
+				new WebSessionServerOAuth2AuthorizedClientRepository();
+
+		private OAuth2ClientMutator() {
+		}
+
+		private OAuth2ClientMutator(String registrationId) {
+			this.registrationId = registrationId;
+			clientRegistration(c -> {});
+		}
+
+		/**
+		 * Use this {@link ClientRegistration}
+		 *
+		 * @param clientRegistration
+		 * @return the {@link SecurityMockMvcRequestPostProcessors.OAuth2ClientRequestPostProcessor} for further configuration
+		 */
+		public OAuth2ClientMutator clientRegistration(ClientRegistration clientRegistration) {
+			this.clientRegistration = clientRegistration;
+			return this;
+		}
+
+		/**
+		 * Use this {@link Consumer} to configure a {@link ClientRegistration}
+		 *
+		 * @param clientRegistrationConfigurer the {@link ClientRegistration} configurer
+		 * @return the {@link SecurityMockMvcRequestPostProcessors.OAuth2ClientRequestPostProcessor} for further configuration
+		 */
+		public OAuth2ClientMutator clientRegistration
+				(Consumer<ClientRegistration.Builder> clientRegistrationConfigurer) {
+
+			ClientRegistration.Builder builder = clientRegistrationBuilder();
+			clientRegistrationConfigurer.accept(builder);
+			this.clientRegistration = builder.build();
+			return this;
+		}
+
+		/**
+		 * Use this {@link OAuth2AccessToken}
+		 *
+		 * @param accessToken the {@link OAuth2AccessToken} to use
+		 * @return the {@link SecurityMockMvcRequestPostProcessors.OAuth2ClientRequestPostProcessor} for further configuration
+		 */
+		public OAuth2ClientMutator accessToken(OAuth2AccessToken accessToken) {
+			this.accessToken = accessToken;
+			return this;
+		}
+
+
+		@Override
+		public void beforeServerCreated(WebHttpHandlerBuilder builder) {
+			builder.filters(addAuthorizedClientFilter());
+		}
+
+		@Override
+		public void afterConfigureAdded(WebTestClient.MockServerSpec<?> serverSpec) {
+
+		}
+
+		@Override
+		public void afterConfigurerAdded(
+				WebTestClient.Builder builder,
+				@Nullable WebHttpHandlerBuilder httpHandlerBuilder,
+				@Nullable ClientHttpConnector connector) {
+			httpHandlerBuilder.filters(addAuthorizedClientFilter());
+		}
+
+		private Consumer<List<WebFilter>> addAuthorizedClientFilter() {
+			OAuth2AuthorizedClient client = getClient();
+			return filters -> filters.add(0, (exchange, chain) ->
+					authorizedClientRepository.saveAuthorizedClient(client, null, exchange)
+							.then(chain.filter(exchange)));
+		}
+
+		private OAuth2AuthorizedClient getClient() {
+			if (this.clientRegistration == null) {
+				throw new IllegalArgumentException("Please specify a ClientRegistration via one " +
+						"of the clientRegistration methods");
+			}
+			return new OAuth2AuthorizedClient(this.clientRegistration, "test-subject", this.accessToken);
+		}
+
+		private ClientRegistration.Builder clientRegistrationBuilder() {
+			return ClientRegistration.withRegistrationId(this.registrationId)
+					.authorizationGrantType(AuthorizationGrantType.PASSWORD)
+					.clientId("test-client")
+					.clientSecret("test-secret")
+					.tokenUri("https://idp.example.org/oauth/token");
+		}
+	}
 }

+ 166 - 0
test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersOAuth2ClientTests.java

@@ -0,0 +1,166 @@
+/*
+ * Copyright 2002-2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.test.web.reactive.server;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnitRunner;
+
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.MediaType;
+import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
+import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
+import org.springframework.security.oauth2.client.web.reactive.result.method.annotation.OAuth2AuthorizedClientArgumentResolver;
+import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
+import org.springframework.security.oauth2.client.web.server.WebSessionServerOAuth2AuthorizedClientRepository;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter;
+import org.springframework.test.web.reactive.server.WebTestClient;
+import org.springframework.web.bind.annotation.GetMapping;
+import org.springframework.web.bind.annotation.RestController;
+import org.springframework.web.reactive.DispatcherHandler;
+import org.springframework.web.server.adapter.WebHttpHandlerBuilder;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatCode;
+import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration;
+import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes;
+import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.mockOAuth2Client;
+import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.springSecurity;
+
+@RunWith(MockitoJUnitRunner.class)
+public class SecurityMockServerConfigurersOAuth2ClientTests extends AbstractMockServerConfigurersTests {
+	private OAuth2LoginController controller = new OAuth2LoginController();
+
+	@Mock
+	private ReactiveClientRegistrationRepository clientRegistrationRepository;
+
+	private WebTestClient client;
+
+	@Before
+	public void setup() {
+		ServerOAuth2AuthorizedClientRepository authorizedClientRepository =
+				new WebSessionServerOAuth2AuthorizedClientRepository();
+
+		this.client = WebTestClient
+				.bindToController(this.controller)
+				.argumentResolvers(c -> c.addCustomResolver(
+						new OAuth2AuthorizedClientArgumentResolver
+								(this.clientRegistrationRepository, authorizedClientRepository)))
+				.webFilter(new SecurityContextServerWebExchangeWebFilter())
+				.apply(springSecurity())
+				.configureClient()
+				.defaultHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
+				.build();
+	}
+
+	@Test
+	public void oauth2ClientWhenUsingDefaultsThenException()
+			throws Exception {
+
+		WebHttpHandlerBuilder builder = WebHttpHandlerBuilder.webHandler(new DispatcherHandler());
+		assertThatCode(() -> mockOAuth2Client().beforeServerCreated(builder))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessageContaining("ClientRegistration");
+	}
+
+	@Test
+	public void oauth2ClientWhenUsingRegistrationIdThenProducesAuthorizedClient()
+			throws Exception {
+
+		this.client.mutateWith(mockOAuth2Client("registration-id"))
+				.get().uri("/client")
+				.exchange()
+				.expectStatus().isOk();
+
+		OAuth2AuthorizedClient client = this.controller.authorizedClient;
+		assertThat(client).isNotNull();
+		assertThat(client.getClientRegistration().getRegistrationId()).isEqualTo("registration-id");
+		assertThat(client.getAccessToken().getTokenValue()).isEqualTo("access-token");
+		assertThat(client.getRefreshToken()).isNull();
+	}
+
+	@Test
+	public void oauth2ClientWhenClientRegistrationThenUses()
+			throws Exception {
+
+		ClientRegistration clientRegistration = clientRegistration()
+				.registrationId("registration-id").clientId("client-id").build();
+		this.client.mutateWith(mockOAuth2Client().clientRegistration(clientRegistration))
+				.get().uri("/client")
+				.exchange()
+				.expectStatus().isOk();
+
+		OAuth2AuthorizedClient client = this.controller.authorizedClient;
+		assertThat(client).isNotNull();
+		assertThat(client.getClientRegistration().getRegistrationId()).isEqualTo("registration-id");
+		assertThat(client.getAccessToken().getTokenValue()).isEqualTo("access-token");
+		assertThat(client.getRefreshToken()).isNull();
+	}
+
+	@Test
+	public void oauth2ClientWhenClientRegistrationConsumerThenUses()
+			throws Exception {
+
+		this.client.mutateWith(mockOAuth2Client("registration-id")
+				.clientRegistration(c -> c.clientId("client-id")))
+				.get().uri("/client")
+				.exchange()
+				.expectStatus().isOk();
+
+		OAuth2AuthorizedClient client = this.controller.authorizedClient;
+		assertThat(client).isNotNull();
+		assertThat(client.getClientRegistration().getRegistrationId()).isEqualTo("registration-id");
+		assertThat(client.getClientRegistration().getClientId()).isEqualTo("client-id");
+		assertThat(client.getAccessToken().getTokenValue()).isEqualTo("access-token");
+		assertThat(client.getRefreshToken()).isNull();
+	}
+
+	@Test
+	public void oauth2ClientWhenAccessTokenThenUses()
+			throws Exception {
+
+		OAuth2AccessToken accessToken = noScopes();
+		this.client.mutateWith(mockOAuth2Client("registration-id")
+				.accessToken(accessToken))
+				.get().uri("/client")
+				.exchange()
+				.expectStatus().isOk();
+
+		OAuth2AuthorizedClient client = this.controller.authorizedClient;
+		assertThat(client).isNotNull();
+		assertThat(client.getClientRegistration().getRegistrationId()).isEqualTo("registration-id");
+		assertThat(client.getAccessToken().getTokenValue()).isEqualTo("no-scopes");
+		assertThat(client.getRefreshToken()).isNull();
+	}
+
+	@RestController
+	static class OAuth2LoginController {
+		volatile OAuth2AuthorizedClient authorizedClient;
+
+		@GetMapping("/client")
+		String authorizedClient
+				(@RegisteredOAuth2AuthorizedClient("registration-id") OAuth2AuthorizedClient authorizedClient) {
+			this.authorizedClient = authorizedClient;
+			return authorizedClient.getPrincipalName();
+		}
+	}
+}