瀏覽代碼

OAuth2AuthorizationCodeGrantWebFilter matches on query parameters

Fixes gh-7966
Joe Grandja 5 年之前
父節點
當前提交
0809c04aa2

+ 40 - 19
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilter.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -37,13 +37,20 @@ import org.springframework.security.web.server.authentication.ServerAuthenticati
 import org.springframework.security.web.server.authentication.ServerAuthenticationSuccessHandler;
 import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
 import org.springframework.util.Assert;
-import org.springframework.util.MultiValueMap;
 import org.springframework.web.server.ServerWebExchange;
 import org.springframework.web.server.WebFilter;
 import org.springframework.web.server.WebFilterChain;
+import org.springframework.web.util.UriComponents;
 import org.springframework.web.util.UriComponentsBuilder;
 import reactor.core.publisher.Mono;
 
+import java.net.URI;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+
 /**
  * A {@code Filter} for the OAuth 2.0 Authorization Code Grant,
  * which handles the processing of the OAuth 2.0 Authorization Response.
@@ -165,10 +172,10 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter {
 	@Override
 	public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
 		return this.requiresAuthenticationMatcher.matches(exchange)
-				.filter( matchResult -> matchResult.isMatch())
-				.flatMap( matchResult -> this.authenticationConverter.convert(exchange))
+				.filter(ServerWebExchangeMatcher.MatchResult::isMatch)
+				.flatMap(matchResult -> this.authenticationConverter.convert(exchange))
 				.switchIfEmpty(chain.filter(exchange).then(Mono.empty()))
-				.flatMap( token -> authenticate(exchange, chain, token));
+				.flatMap(token -> authenticate(exchange, chain, token));
 	}
 
 	private Mono<Void> authenticate(ServerWebExchange exchange,
@@ -198,20 +205,34 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter {
 	}
 
 	private Mono<ServerWebExchangeMatcher.MatchResult> matchesAuthorizationResponse(ServerWebExchange exchange) {
-		return this.authorizationRequestRepository.loadAuthorizationRequest(exchange)
-				.flatMap(authorizationRequest -> {
-					String requestUrl = UriComponentsBuilder.fromUri(exchange.getRequest().getURI())
-							.query(null)
-							.build()
-							.toUriString();
-					MultiValueMap<String, String> queryParams = exchange.getRequest().getQueryParams();
-					if (requestUrl.equals(authorizationRequest.getRedirectUri()) &&
-							OAuth2AuthorizationResponseUtils.isAuthorizationResponse(queryParams)) {
-						return ServerWebExchangeMatcher.MatchResult.match();
-					}
-					return ServerWebExchangeMatcher.MatchResult.notMatch();
-				})
-				.filter(ServerWebExchangeMatcher.MatchResult::isMatch)
+		return Mono.just(exchange)
+				.filter(exch -> OAuth2AuthorizationResponseUtils.isAuthorizationResponse(exch.getRequest().getQueryParams()))
+				.flatMap(exch -> this.authorizationRequestRepository.loadAuthorizationRequest(exchange)
+						.flatMap(authorizationRequest ->
+								matchesRedirectUri(exch.getRequest().getURI(), authorizationRequest.getRedirectUri())))
 				.switchIfEmpty(ServerWebExchangeMatcher.MatchResult.notMatch());
 	}
+
+	private static Mono<ServerWebExchangeMatcher.MatchResult> matchesRedirectUri(
+			URI authorizationResponseUri, String authorizationRequestRedirectUri) {
+		UriComponents requestUri = UriComponentsBuilder.fromUri(authorizationResponseUri).build();
+		UriComponents redirectUri = UriComponentsBuilder.fromUriString(authorizationRequestRedirectUri).build();
+		Set<Map.Entry<String, List<String>>> requestUriParameters =
+				new LinkedHashSet<>(requestUri.getQueryParams().entrySet());
+		Set<Map.Entry<String, List<String>>> redirectUriParameters =
+				new LinkedHashSet<>(redirectUri.getQueryParams().entrySet());
+		// Remove the additional request parameters (if any) from the authorization response (request)
+		// before doing an exact comparison with the authorizationRequest.getRedirectUri() parameters (if any)
+		requestUriParameters.retainAll(redirectUriParameters);
+
+		if (Objects.equals(requestUri.getScheme(), redirectUri.getScheme()) &&
+				Objects.equals(requestUri.getUserInfo(), redirectUri.getUserInfo()) &&
+				Objects.equals(requestUri.getHost(), redirectUri.getHost()) &&
+				Objects.equals(requestUri.getPort(), redirectUri.getPort()) &&
+				Objects.equals(requestUri.getPath(), redirectUri.getPath()) &&
+				Objects.equals(requestUriParameters.toString(), redirectUriParameters.toString())) {
+			return ServerWebExchangeMatcher.MatchResult.match();
+		}
+		return ServerWebExchangeMatcher.MatchResult.notMatch();
+	}
 }

+ 2 - 7
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizationCodeAuthenticationTokenConverter.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -28,7 +28,6 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResp
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.web.server.authentication.ServerAuthenticationConverter;
 import org.springframework.util.Assert;
-import org.springframework.util.MultiValueMap;
 import org.springframework.web.server.ServerWebExchange;
 import org.springframework.web.util.UriComponentsBuilder;
 import reactor.core.publisher.Mono;
@@ -103,14 +102,10 @@ public class ServerOAuth2AuthorizationCodeAuthenticationTokenConverter
 	}
 
 	private static OAuth2AuthorizationResponse convertResponse(ServerWebExchange exchange) {
-		MultiValueMap<String, String> queryParams = exchange.getRequest()
-				.getQueryParams();
 		String redirectUri = UriComponentsBuilder.fromUri(exchange.getRequest().getURI())
-				.query(null)
 				.build()
 				.toUriString();
-
 		return OAuth2AuthorizationResponseUtils
-				.convert(queryParams, redirectUri);
+				.convert(exchange.getRequest().getQueryParams(), redirectUri);
 	}
 }

+ 161 - 40
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilterTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -25,25 +25,28 @@ import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
 import org.springframework.mock.web.server.MockServerWebExchange;
 import org.springframework.security.authentication.AnonymousAuthenticationToken;
 import org.springframework.security.authentication.ReactiveAuthenticationManager;
-import org.springframework.security.core.Authentication;
-import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken;
 import org.springframework.security.oauth2.client.authentication.TestOAuth2AuthorizationCodeAuthenticationTokens;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
-import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
-import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
-import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests;
-import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses;
-import org.springframework.security.web.server.authentication.ServerAuthenticationConverter;
+import org.springframework.util.CollectionUtils;
 import org.springframework.web.server.handler.DefaultWebFilterChain;
 import reactor.core.publisher.Mono;
 
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.Map;
+
 import static org.assertj.core.api.Assertions.assertThatCode;
 import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.Mockito.*;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoInteractions;
+import static org.mockito.Mockito.when;
+import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests.request;
 
 /**
  * @author Rob Winch
@@ -102,52 +105,170 @@ public class OAuth2AuthorizationCodeGrantWebFilterTests {
 		MockServerWebExchange exchange = MockServerWebExchange
 				.from(MockServerHttpRequest.get("/"));
 		DefaultWebFilterChain chain = new DefaultWebFilterChain(
-				e -> e.getResponse().setComplete());
+				e -> e.getResponse().setComplete(), Collections.emptyList());
 
 		this.filter.filter(exchange, chain).block();
 
-		verifyZeroInteractions(this.authenticationManager);
+		verifyNoInteractions(this.authenticationManager);
 	}
 
 	@Test
 	public void filterWhenMatchThenAuthorizedClientSaved() {
-		OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request()
-				.redirectUri("/authorize/registration-id")
-				.build();
-		OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.success()
-				.redirectUri("/authorize/registration-id")
-				.build();
-		OAuth2AuthorizationExchange authorizationExchange =
-				new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse);
-		ClientRegistration registration = TestClientRegistrations.clientRegistration().build();
-		Mono<Authentication> authentication = Mono.just(
-				new OAuth2AuthorizationCodeAuthenticationToken(registration, authorizationExchange));
-		OAuth2AuthorizationCodeAuthenticationToken authenticated = TestOAuth2AuthorizationCodeAuthenticationTokens
-				.authenticated();
-
-		when(this.authenticationManager.authenticate(any())).thenReturn(
-				Mono.just(authenticated));
+		ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
+		when(this.clientRegistrationRepository.findByRegistrationId(any()))
+				.thenReturn(Mono.just(clientRegistration));
+
+		MockServerHttpRequest authorizationRequest =
+				createAuthorizationRequest("/authorization/callback");
+		OAuth2AuthorizationRequest oauth2AuthorizationRequest =
+				createOAuth2AuthorizationRequest(authorizationRequest, clientRegistration);
 		when(this.authorizationRequestRepository.loadAuthorizationRequest(any()))
-				.thenReturn(Mono.just(authorizationRequest));
+				.thenReturn(Mono.just(oauth2AuthorizationRequest));
+		when(this.authorizationRequestRepository.removeAuthorizationRequest(any()))
+				.thenReturn(Mono.just(oauth2AuthorizationRequest));
+
 		when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any()))
 				.thenReturn(Mono.empty());
-		ServerAuthenticationConverter converter = e -> authentication;
-
-		this.filter = new OAuth2AuthorizationCodeGrantWebFilter(
-				this.authenticationManager, converter, this.authorizedClientRepository);
-		this.filter.setAuthorizationRequestRepository(this.authorizationRequestRepository);
+		when(this.authenticationManager.authenticate(any()))
+				.thenReturn(Mono.just(TestOAuth2AuthorizationCodeAuthenticationTokens.authenticated()));
 
-		MockServerHttpRequest request = MockServerHttpRequest
-				.get("/authorize/registration-id")
-				.queryParam(OAuth2ParameterNames.CODE, "code")
-				.queryParam(OAuth2ParameterNames.STATE, "state")
-				.build();
-		MockServerWebExchange exchange = MockServerWebExchange.from(request);
+		MockServerHttpRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
+		MockServerWebExchange exchange = MockServerWebExchange.from(authorizationResponse);
 		DefaultWebFilterChain chain = new DefaultWebFilterChain(
-				e -> e.getResponse().setComplete());
+				e -> e.getResponse().setComplete(), Collections.emptyList());
 
 		this.filter.filter(exchange, chain).block();
 
 		verify(this.authorizedClientRepository).saveAuthorizedClient(any(), any(AnonymousAuthenticationToken.class), any());
 	}
+
+	// gh-7966
+	@Test
+	public void filterWhenAuthorizationRequestRedirectUriParametersMatchThenProcessed() {
+		ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
+		when(this.clientRegistrationRepository.findByRegistrationId(any()))
+				.thenReturn(Mono.just(clientRegistration));
+		when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any()))
+				.thenReturn(Mono.empty());
+		when(this.authenticationManager.authenticate(any()))
+				.thenReturn(Mono.just(TestOAuth2AuthorizationCodeAuthenticationTokens.authenticated()));
+
+		// 1) redirect_uri with query parameters
+		Map<String, String> parameters = new LinkedHashMap<>();
+		parameters.put("param1", "value1");
+		parameters.put("param2", "value2");
+		MockServerHttpRequest authorizationRequest =
+				createAuthorizationRequest("/authorization/callback", parameters);
+		OAuth2AuthorizationRequest oauth2AuthorizationRequest =
+				createOAuth2AuthorizationRequest(authorizationRequest, clientRegistration);
+		when(this.authorizationRequestRepository.loadAuthorizationRequest(any()))
+				.thenReturn(Mono.just(oauth2AuthorizationRequest));
+		when(this.authorizationRequestRepository.removeAuthorizationRequest(any()))
+				.thenReturn(Mono.just(oauth2AuthorizationRequest));
+
+		MockServerHttpRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
+		MockServerWebExchange exchange = MockServerWebExchange.from(authorizationResponse);
+		DefaultWebFilterChain chain = new DefaultWebFilterChain(
+				e -> e.getResponse().setComplete(), Collections.emptyList());
+
+		this.filter.filter(exchange, chain).block();
+		verify(this.authenticationManager, times(1)).authenticate(any());
+
+		// 2) redirect_uri with query parameters AND authorization response additional parameters
+		Map<String, String> additionalParameters = new LinkedHashMap<>();
+		additionalParameters.put("auth-param1", "value1");
+		additionalParameters.put("auth-param2", "value2");
+		authorizationResponse = createAuthorizationResponse(authorizationRequest, additionalParameters);
+		exchange = MockServerWebExchange.from(authorizationResponse);
+
+		this.filter.filter(exchange, chain).block();
+		verify(this.authenticationManager, times(2)).authenticate(any());
+	}
+
+	// gh-7966
+	@Test
+	public void filterWhenAuthorizationRequestRedirectUriParametersNotMatchThenNotProcessed() {
+		String requestUri = "/authorization/callback";
+		Map<String, String> parameters = new LinkedHashMap<>();
+		parameters.put("param1", "value1");
+		parameters.put("param2", "value2");
+		MockServerHttpRequest authorizationRequest =
+				createAuthorizationRequest(requestUri, parameters);
+		ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
+		OAuth2AuthorizationRequest oauth2AuthorizationRequest =
+				createOAuth2AuthorizationRequest(authorizationRequest, clientRegistration);
+		when(this.authorizationRequestRepository.loadAuthorizationRequest(any()))
+				.thenReturn(Mono.just(oauth2AuthorizationRequest));
+
+		// 1) Parameter value
+		Map<String, String> parametersNotMatch = new LinkedHashMap<>(parameters);
+		parametersNotMatch.put("param2", "value8");
+		MockServerHttpRequest authorizationResponse = createAuthorizationResponse(
+				createAuthorizationRequest(requestUri, parametersNotMatch));
+		MockServerWebExchange exchange = MockServerWebExchange.from(authorizationResponse);
+		DefaultWebFilterChain chain = new DefaultWebFilterChain(
+				e -> e.getResponse().setComplete(), Collections.emptyList());
+
+		this.filter.filter(exchange, chain).block();
+		verifyNoInteractions(this.authenticationManager);
+
+		// 2) Parameter order
+		parametersNotMatch = new LinkedHashMap<>();
+		parametersNotMatch.put("param2", "value2");
+		parametersNotMatch.put("param1", "value1");
+		authorizationResponse = createAuthorizationResponse(
+				createAuthorizationRequest(requestUri, parametersNotMatch));
+		exchange = MockServerWebExchange.from(authorizationResponse);
+
+		this.filter.filter(exchange, chain).block();
+		verifyNoInteractions(this.authenticationManager);
+
+		// 3) Parameter missing
+		parametersNotMatch = new LinkedHashMap<>(parameters);
+		parametersNotMatch.remove("param2");
+		authorizationResponse = createAuthorizationResponse(
+				createAuthorizationRequest(requestUri, parametersNotMatch));
+		exchange = MockServerWebExchange.from(authorizationResponse);
+
+		this.filter.filter(exchange, chain).block();
+		verifyNoInteractions(this.authenticationManager);
+	}
+
+	private static OAuth2AuthorizationRequest createOAuth2AuthorizationRequest(
+			MockServerHttpRequest authorizationRequest, ClientRegistration registration) {
+		Map<String, Object> attributes = new HashMap<>();
+		attributes.put(OAuth2ParameterNames.REGISTRATION_ID, registration.getRegistrationId());
+		return request()
+				.attributes(attributes)
+				.redirectUri(authorizationRequest.getURI().toString())
+				.build();
+	}
+
+	private static MockServerHttpRequest createAuthorizationRequest(String requestUri) {
+		return createAuthorizationRequest(requestUri, new LinkedHashMap<>());
+	}
+
+	private static MockServerHttpRequest createAuthorizationRequest(String requestUri, Map<String, String> parameters) {
+		MockServerHttpRequest.BaseBuilder<?> builder = MockServerHttpRequest
+				.get(requestUri);
+		if (!CollectionUtils.isEmpty(parameters)) {
+			parameters.forEach(builder::queryParam);
+		}
+		return builder.build();
+	}
+
+	private static MockServerHttpRequest createAuthorizationResponse(MockServerHttpRequest authorizationRequest) {
+		return createAuthorizationResponse(authorizationRequest, new LinkedHashMap<>());
+	}
+
+	private static MockServerHttpRequest createAuthorizationResponse(
+			MockServerHttpRequest authorizationRequest, Map<String, String> additionalParameters) {
+		MockServerHttpRequest.BaseBuilder<?> builder = MockServerHttpRequest
+				.get(authorizationRequest.getURI().toString());
+		builder.queryParam(OAuth2ParameterNames.CODE, "code");
+		builder.queryParam(OAuth2ParameterNames.STATE, "state");
+		additionalParameters.forEach(builder::queryParam);
+		builder.cookies(authorizationRequest.getCookies());
+		return builder.build();
+	}
 }