Explorar o código

Add new configuration options for OAuth2LoginSpec

Fixes gh-5598
Nick Bromfield %!s(int64=6) %!d(string=hai) anos
pai
achega
b581bb7eae

+ 50 - 3
config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2019 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -53,8 +53,10 @@ import org.springframework.security.oauth2.client.web.server.AuthenticatedPrinci
 import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationCodeGrantWebFilter;
 import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationRequestRedirectWebFilter;
 import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationCodeAuthenticationTokenConverter;
+import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationRequestResolver;
 import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.client.web.server.authentication.OAuth2LoginAuthenticationWebFilter;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 import org.springframework.security.oauth2.core.oidc.user.OidcUser;
 import org.springframework.security.oauth2.core.user.OAuth2User;
 import org.springframework.security.oauth2.jwt.Jwt;
@@ -588,6 +590,10 @@ public class ServerHttpSecurity {
 
 		private ServerAuthenticationConverter authenticationConverter;
 
+		private ServerOAuth2AuthorizationRequestResolver authorizationRequestResolver;
+
+		private ServerWebExchangeMatcher authenticationMatcher;
+
 		/**
 		 * Configures the {@link ReactiveAuthenticationManager} to use. The default is
 		 * {@link OAuth2AuthorizationCodeReactiveAuthenticationManager}
@@ -664,6 +670,37 @@ public class ServerHttpSecurity {
 			return this;
 		}
 
+		/**
+		 * Sets the resolver used for resolving {@link OAuth2AuthorizationRequest}'s.
+		 *
+		 * @since 5.2
+		 * @param authorizationRequestResolver the resolver used for resolving {@link OAuth2AuthorizationRequest}'s
+		 * @return the {@link OAuth2LoginSpec} for further configuration
+		 */
+		public OAuth2LoginSpec authorizationRequestResolver(ServerOAuth2AuthorizationRequestResolver authorizationRequestResolver) {
+			this.authorizationRequestResolver = authorizationRequestResolver;
+			return this;
+		}
+
+		/**
+		 * Sets the {@link ServerWebExchangeMatcher matcher} used for determining if the request is an authentication request.
+		 *
+		 * @since 5.2
+		 * @param authenticationMatcher the {@link ServerWebExchangeMatcher matcher} used for determining if the request is an authentication request
+		 * @return the {@link OAuth2LoginSpec} for further configuration
+		 */
+		public OAuth2LoginSpec authenticationMatcher(ServerWebExchangeMatcher authenticationMatcher) {
+			this.authenticationMatcher = authenticationMatcher;
+			return this;
+		}
+
+		private ServerWebExchangeMatcher getAuthenticationMatcher() {
+			if (this.authenticationMatcher == null) {
+				this.authenticationMatcher = createAttemptAuthenticationRequestMatcher();
+			}
+			return this.authenticationMatcher;
+		}
+
 		/**
 		 * Allows method chaining to continue configuring the {@link ServerHttpSecurity}
 		 * @return the {@link ServerHttpSecurity} to continue configuring
@@ -676,12 +713,12 @@ public class ServerHttpSecurity {
 		protected void configure(ServerHttpSecurity http) {
 			ReactiveClientRegistrationRepository clientRegistrationRepository = getClientRegistrationRepository();
 			ServerOAuth2AuthorizedClientRepository authorizedClientRepository = getAuthorizedClientRepository();
-			OAuth2AuthorizationRequestRedirectWebFilter oauthRedirectFilter = new OAuth2AuthorizationRequestRedirectWebFilter(clientRegistrationRepository);
+			OAuth2AuthorizationRequestRedirectWebFilter oauthRedirectFilter = getRedirectWebFilter();
 
 			ReactiveAuthenticationManager manager = getAuthenticationManager();
 
 			AuthenticationWebFilter authenticationFilter = new OAuth2LoginAuthenticationWebFilter(manager, authorizedClientRepository);
-			authenticationFilter.setRequiresAuthenticationMatcher(createAttemptAuthenticationRequestMatcher());
+			authenticationFilter.setRequiresAuthenticationMatcher(getAuthenticationMatcher());
 			authenticationFilter.setServerAuthenticationConverter(getAuthenticationConverter(clientRegistrationRepository));
 			RedirectServerAuthenticationSuccessHandler redirectHandler = new RedirectServerAuthenticationSuccessHandler();
 
@@ -756,6 +793,16 @@ public class ServerHttpSecurity {
 			return this.clientRegistrationRepository;
 		}
 
+		private OAuth2AuthorizationRequestRedirectWebFilter getRedirectWebFilter() {
+			OAuth2AuthorizationRequestRedirectWebFilter oauthRedirectFilter;
+			if (this.authorizationRequestResolver == null) {
+				oauthRedirectFilter = new OAuth2AuthorizationRequestRedirectWebFilter(getClientRegistrationRepository());
+			} else {
+				oauthRedirectFilter = new OAuth2AuthorizationRequestRedirectWebFilter(this.authorizationRequestResolver);
+			}
+			return oauthRedirectFilter;
+		}
+
 		private ServerOAuth2AuthorizedClientRepository getAuthorizedClientRepository() {
 			ServerOAuth2AuthorizedClientRepository result = this.authorizedClientRepository;
 			if (result == null) {

+ 19 - 5
config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2019 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -37,6 +37,7 @@ import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.InMemoryReactiveClientRegistrationRepository;
 import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService;
+import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationRequestResolver;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
@@ -59,6 +60,7 @@ import org.springframework.security.test.web.reactive.server.WebTestClientBuilde
 import org.springframework.security.web.server.SecurityWebFilterChain;
 import org.springframework.security.web.server.WebFilterChainProxy;
 import org.springframework.security.web.server.authentication.ServerAuthenticationConverter;
+import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
 import org.springframework.test.web.reactive.server.WebTestClient;
 import org.springframework.web.server.ServerWebExchange;
 import org.springframework.web.server.WebFilter;
@@ -100,7 +102,7 @@ public class OAuth2LoginTests {
 
 	@Test
 	public void defaultLoginPageWithMultipleClientRegistrationsThenLinks() {
-		this.spring.register(OAuth2LoginWithMulitpleClientRegistrations.class).autowire();
+		this.spring.register(OAuth2LoginWithMultipleClientRegistrations.class).autowire();
 
 		WebTestClient webTestClient = WebTestClientBuilder
 				.bindToWebFilters(this.springSecurity)
@@ -120,7 +122,7 @@ public class OAuth2LoginTests {
 	}
 
 	@EnableWebFluxSecurity
-	static class OAuth2LoginWithMulitpleClientRegistrations {
+	static class OAuth2LoginWithMultipleClientRegistrations {
 		@Bean
 		InMemoryReactiveClientRegistrationRepository clientRegistrationRepository() {
 			return new InMemoryReactiveClientRegistrationRepository(github, google);
@@ -165,6 +167,8 @@ public class OAuth2LoginTests {
 				.getBean(OAuth2LoginMockAuthenticationManagerConfig.class);
 		ServerAuthenticationConverter converter = config.authenticationConverter;
 		ReactiveAuthenticationManager manager = config.manager;
+		ServerWebExchangeMatcher matcher = config.matcher;
+		ServerOAuth2AuthorizationRequestResolver resolver = config.resolver;
 
 		OAuth2AuthorizationExchange exchange = TestOAuth2AuthorizationExchanges.success();
 		OAuth2User user = TestOAuth2Users.create();
@@ -174,6 +178,8 @@ public class OAuth2LoginTests {
 
 		when(converter.convert(any())).thenReturn(Mono.just(new TestingAuthenticationToken("a", "b", "c")));
 		when(manager.authenticate(any())).thenReturn(Mono.just(result));
+		when(matcher.matches(any())).thenReturn(ServerWebExchangeMatcher.MatchResult.match());
+		when(resolver.resolve(any())).thenReturn(Mono.empty());
 
 		webTestClient.get()
 			.uri("/login/oauth2/code/github")
@@ -182,6 +188,8 @@ public class OAuth2LoginTests {
 
 		verify(converter).convert(any());
 		verify(manager).authenticate(any());
+		verify(matcher).matches(any());
+		verify(resolver).resolve(any());
 	}
 
 	@Configuration
@@ -190,6 +198,10 @@ public class OAuth2LoginTests {
 
 		ServerAuthenticationConverter authenticationConverter = mock(ServerAuthenticationConverter.class);
 
+		ServerWebExchangeMatcher matcher = mock(ServerWebExchangeMatcher.class);
+
+		ServerOAuth2AuthorizationRequestResolver resolver = mock(ServerOAuth2AuthorizationRequestResolver.class);
+
 		@Bean
 		public SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) {
 			http
@@ -198,14 +210,16 @@ public class OAuth2LoginTests {
 					.and()
 				.oauth2Login()
 					.authenticationConverter(authenticationConverter)
-					.authenticationManager(manager);
+					.authenticationManager(manager)
+					.authenticationMatcher(matcher)
+					.authorizationRequestResolver(resolver);
 			return http.build();
 		}
 	}
 
 	@Test
 	public void oauth2LoginWhenCustomJwtDecoderFactoryThenUsed() {
-		this.spring.register(OAuth2LoginWithMulitpleClientRegistrations.class,
+		this.spring.register(OAuth2LoginWithMultipleClientRegistrations.class,
 				OAuth2LoginWithJwtDecoderFactoryBeanConfig.class).autowire();
 
 		WebTestClient webTestClient = WebTestClientBuilder