|
@@ -1,5 +1,5 @@
|
|
|
/*
|
|
|
- * Copyright 2002-2019 the original author or authors.
|
|
|
+ * Copyright 2002-2020 the original author or authors.
|
|
|
*
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
* you may not use this file except in compliance with the License.
|
|
@@ -25,25 +25,28 @@ import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
|
|
|
import org.springframework.mock.web.server.MockServerWebExchange;
|
|
|
import org.springframework.security.authentication.AnonymousAuthenticationToken;
|
|
|
import org.springframework.security.authentication.ReactiveAuthenticationManager;
|
|
|
-import org.springframework.security.core.Authentication;
|
|
|
-import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken;
|
|
|
import org.springframework.security.oauth2.client.authentication.TestOAuth2AuthorizationCodeAuthenticationTokens;
|
|
|
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
|
|
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
|
|
|
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
|
|
|
-import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
|
|
|
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
|
|
|
-import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
|
|
|
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
|
|
-import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests;
|
|
|
-import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses;
|
|
|
-import org.springframework.security.web.server.authentication.ServerAuthenticationConverter;
|
|
|
+import org.springframework.util.CollectionUtils;
|
|
|
import org.springframework.web.server.handler.DefaultWebFilterChain;
|
|
|
import reactor.core.publisher.Mono;
|
|
|
|
|
|
+import java.util.Collections;
|
|
|
+import java.util.HashMap;
|
|
|
+import java.util.LinkedHashMap;
|
|
|
+import java.util.Map;
|
|
|
+
|
|
|
import static org.assertj.core.api.Assertions.assertThatCode;
|
|
|
import static org.mockito.ArgumentMatchers.any;
|
|
|
-import static org.mockito.Mockito.*;
|
|
|
+import static org.mockito.Mockito.times;
|
|
|
+import static org.mockito.Mockito.verify;
|
|
|
+import static org.mockito.Mockito.verifyNoInteractions;
|
|
|
+import static org.mockito.Mockito.when;
|
|
|
+import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests.request;
|
|
|
|
|
|
/**
|
|
|
* @author Rob Winch
|
|
@@ -102,52 +105,170 @@ public class OAuth2AuthorizationCodeGrantWebFilterTests {
|
|
|
MockServerWebExchange exchange = MockServerWebExchange
|
|
|
.from(MockServerHttpRequest.get("/"));
|
|
|
DefaultWebFilterChain chain = new DefaultWebFilterChain(
|
|
|
- e -> e.getResponse().setComplete());
|
|
|
+ e -> e.getResponse().setComplete(), Collections.emptyList());
|
|
|
|
|
|
this.filter.filter(exchange, chain).block();
|
|
|
|
|
|
- verifyZeroInteractions(this.authenticationManager);
|
|
|
+ verifyNoInteractions(this.authenticationManager);
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
public void filterWhenMatchThenAuthorizedClientSaved() {
|
|
|
- OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request()
|
|
|
- .redirectUri("/authorize/registration-id")
|
|
|
- .build();
|
|
|
- OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.success()
|
|
|
- .redirectUri("/authorize/registration-id")
|
|
|
- .build();
|
|
|
- OAuth2AuthorizationExchange authorizationExchange =
|
|
|
- new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse);
|
|
|
- ClientRegistration registration = TestClientRegistrations.clientRegistration().build();
|
|
|
- Mono<Authentication> authentication = Mono.just(
|
|
|
- new OAuth2AuthorizationCodeAuthenticationToken(registration, authorizationExchange));
|
|
|
- OAuth2AuthorizationCodeAuthenticationToken authenticated = TestOAuth2AuthorizationCodeAuthenticationTokens
|
|
|
- .authenticated();
|
|
|
-
|
|
|
- when(this.authenticationManager.authenticate(any())).thenReturn(
|
|
|
- Mono.just(authenticated));
|
|
|
+ ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
|
|
|
+ when(this.clientRegistrationRepository.findByRegistrationId(any()))
|
|
|
+ .thenReturn(Mono.just(clientRegistration));
|
|
|
+
|
|
|
+ MockServerHttpRequest authorizationRequest =
|
|
|
+ createAuthorizationRequest("/authorization/callback");
|
|
|
+ OAuth2AuthorizationRequest oauth2AuthorizationRequest =
|
|
|
+ createOAuth2AuthorizationRequest(authorizationRequest, clientRegistration);
|
|
|
when(this.authorizationRequestRepository.loadAuthorizationRequest(any()))
|
|
|
- .thenReturn(Mono.just(authorizationRequest));
|
|
|
+ .thenReturn(Mono.just(oauth2AuthorizationRequest));
|
|
|
+ when(this.authorizationRequestRepository.removeAuthorizationRequest(any()))
|
|
|
+ .thenReturn(Mono.just(oauth2AuthorizationRequest));
|
|
|
+
|
|
|
when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any()))
|
|
|
.thenReturn(Mono.empty());
|
|
|
- ServerAuthenticationConverter converter = e -> authentication;
|
|
|
-
|
|
|
- this.filter = new OAuth2AuthorizationCodeGrantWebFilter(
|
|
|
- this.authenticationManager, converter, this.authorizedClientRepository);
|
|
|
- this.filter.setAuthorizationRequestRepository(this.authorizationRequestRepository);
|
|
|
+ when(this.authenticationManager.authenticate(any()))
|
|
|
+ .thenReturn(Mono.just(TestOAuth2AuthorizationCodeAuthenticationTokens.authenticated()));
|
|
|
|
|
|
- MockServerHttpRequest request = MockServerHttpRequest
|
|
|
- .get("/authorize/registration-id")
|
|
|
- .queryParam(OAuth2ParameterNames.CODE, "code")
|
|
|
- .queryParam(OAuth2ParameterNames.STATE, "state")
|
|
|
- .build();
|
|
|
- MockServerWebExchange exchange = MockServerWebExchange.from(request);
|
|
|
+ MockServerHttpRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
|
|
|
+ MockServerWebExchange exchange = MockServerWebExchange.from(authorizationResponse);
|
|
|
DefaultWebFilterChain chain = new DefaultWebFilterChain(
|
|
|
- e -> e.getResponse().setComplete());
|
|
|
+ e -> e.getResponse().setComplete(), Collections.emptyList());
|
|
|
|
|
|
this.filter.filter(exchange, chain).block();
|
|
|
|
|
|
verify(this.authorizedClientRepository).saveAuthorizedClient(any(), any(AnonymousAuthenticationToken.class), any());
|
|
|
}
|
|
|
+
|
|
|
+ // gh-7966
|
|
|
+ @Test
|
|
|
+ public void filterWhenAuthorizationRequestRedirectUriParametersMatchThenProcessed() {
|
|
|
+ ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
|
|
|
+ when(this.clientRegistrationRepository.findByRegistrationId(any()))
|
|
|
+ .thenReturn(Mono.just(clientRegistration));
|
|
|
+ when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any()))
|
|
|
+ .thenReturn(Mono.empty());
|
|
|
+ when(this.authenticationManager.authenticate(any()))
|
|
|
+ .thenReturn(Mono.just(TestOAuth2AuthorizationCodeAuthenticationTokens.authenticated()));
|
|
|
+
|
|
|
+ // 1) redirect_uri with query parameters
|
|
|
+ Map<String, String> parameters = new LinkedHashMap<>();
|
|
|
+ parameters.put("param1", "value1");
|
|
|
+ parameters.put("param2", "value2");
|
|
|
+ MockServerHttpRequest authorizationRequest =
|
|
|
+ createAuthorizationRequest("/authorization/callback", parameters);
|
|
|
+ OAuth2AuthorizationRequest oauth2AuthorizationRequest =
|
|
|
+ createOAuth2AuthorizationRequest(authorizationRequest, clientRegistration);
|
|
|
+ when(this.authorizationRequestRepository.loadAuthorizationRequest(any()))
|
|
|
+ .thenReturn(Mono.just(oauth2AuthorizationRequest));
|
|
|
+ when(this.authorizationRequestRepository.removeAuthorizationRequest(any()))
|
|
|
+ .thenReturn(Mono.just(oauth2AuthorizationRequest));
|
|
|
+
|
|
|
+ MockServerHttpRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
|
|
|
+ MockServerWebExchange exchange = MockServerWebExchange.from(authorizationResponse);
|
|
|
+ DefaultWebFilterChain chain = new DefaultWebFilterChain(
|
|
|
+ e -> e.getResponse().setComplete(), Collections.emptyList());
|
|
|
+
|
|
|
+ this.filter.filter(exchange, chain).block();
|
|
|
+ verify(this.authenticationManager, times(1)).authenticate(any());
|
|
|
+
|
|
|
+ // 2) redirect_uri with query parameters AND authorization response additional parameters
|
|
|
+ Map<String, String> additionalParameters = new LinkedHashMap<>();
|
|
|
+ additionalParameters.put("auth-param1", "value1");
|
|
|
+ additionalParameters.put("auth-param2", "value2");
|
|
|
+ authorizationResponse = createAuthorizationResponse(authorizationRequest, additionalParameters);
|
|
|
+ exchange = MockServerWebExchange.from(authorizationResponse);
|
|
|
+
|
|
|
+ this.filter.filter(exchange, chain).block();
|
|
|
+ verify(this.authenticationManager, times(2)).authenticate(any());
|
|
|
+ }
|
|
|
+
|
|
|
+ // gh-7966
|
|
|
+ @Test
|
|
|
+ public void filterWhenAuthorizationRequestRedirectUriParametersNotMatchThenNotProcessed() {
|
|
|
+ String requestUri = "/authorization/callback";
|
|
|
+ Map<String, String> parameters = new LinkedHashMap<>();
|
|
|
+ parameters.put("param1", "value1");
|
|
|
+ parameters.put("param2", "value2");
|
|
|
+ MockServerHttpRequest authorizationRequest =
|
|
|
+ createAuthorizationRequest(requestUri, parameters);
|
|
|
+ ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
|
|
|
+ OAuth2AuthorizationRequest oauth2AuthorizationRequest =
|
|
|
+ createOAuth2AuthorizationRequest(authorizationRequest, clientRegistration);
|
|
|
+ when(this.authorizationRequestRepository.loadAuthorizationRequest(any()))
|
|
|
+ .thenReturn(Mono.just(oauth2AuthorizationRequest));
|
|
|
+
|
|
|
+ // 1) Parameter value
|
|
|
+ Map<String, String> parametersNotMatch = new LinkedHashMap<>(parameters);
|
|
|
+ parametersNotMatch.put("param2", "value8");
|
|
|
+ MockServerHttpRequest authorizationResponse = createAuthorizationResponse(
|
|
|
+ createAuthorizationRequest(requestUri, parametersNotMatch));
|
|
|
+ MockServerWebExchange exchange = MockServerWebExchange.from(authorizationResponse);
|
|
|
+ DefaultWebFilterChain chain = new DefaultWebFilterChain(
|
|
|
+ e -> e.getResponse().setComplete(), Collections.emptyList());
|
|
|
+
|
|
|
+ this.filter.filter(exchange, chain).block();
|
|
|
+ verifyNoInteractions(this.authenticationManager);
|
|
|
+
|
|
|
+ // 2) Parameter order
|
|
|
+ parametersNotMatch = new LinkedHashMap<>();
|
|
|
+ parametersNotMatch.put("param2", "value2");
|
|
|
+ parametersNotMatch.put("param1", "value1");
|
|
|
+ authorizationResponse = createAuthorizationResponse(
|
|
|
+ createAuthorizationRequest(requestUri, parametersNotMatch));
|
|
|
+ exchange = MockServerWebExchange.from(authorizationResponse);
|
|
|
+
|
|
|
+ this.filter.filter(exchange, chain).block();
|
|
|
+ verifyNoInteractions(this.authenticationManager);
|
|
|
+
|
|
|
+ // 3) Parameter missing
|
|
|
+ parametersNotMatch = new LinkedHashMap<>(parameters);
|
|
|
+ parametersNotMatch.remove("param2");
|
|
|
+ authorizationResponse = createAuthorizationResponse(
|
|
|
+ createAuthorizationRequest(requestUri, parametersNotMatch));
|
|
|
+ exchange = MockServerWebExchange.from(authorizationResponse);
|
|
|
+
|
|
|
+ this.filter.filter(exchange, chain).block();
|
|
|
+ verifyNoInteractions(this.authenticationManager);
|
|
|
+ }
|
|
|
+
|
|
|
+ private static OAuth2AuthorizationRequest createOAuth2AuthorizationRequest(
|
|
|
+ MockServerHttpRequest authorizationRequest, ClientRegistration registration) {
|
|
|
+ Map<String, Object> attributes = new HashMap<>();
|
|
|
+ attributes.put(OAuth2ParameterNames.REGISTRATION_ID, registration.getRegistrationId());
|
|
|
+ return request()
|
|
|
+ .attributes(attributes)
|
|
|
+ .redirectUri(authorizationRequest.getURI().toString())
|
|
|
+ .build();
|
|
|
+ }
|
|
|
+
|
|
|
+ private static MockServerHttpRequest createAuthorizationRequest(String requestUri) {
|
|
|
+ return createAuthorizationRequest(requestUri, new LinkedHashMap<>());
|
|
|
+ }
|
|
|
+
|
|
|
+ private static MockServerHttpRequest createAuthorizationRequest(String requestUri, Map<String, String> parameters) {
|
|
|
+ MockServerHttpRequest.BaseBuilder<?> builder = MockServerHttpRequest
|
|
|
+ .get(requestUri);
|
|
|
+ if (!CollectionUtils.isEmpty(parameters)) {
|
|
|
+ parameters.forEach(builder::queryParam);
|
|
|
+ }
|
|
|
+ return builder.build();
|
|
|
+ }
|
|
|
+
|
|
|
+ private static MockServerHttpRequest createAuthorizationResponse(MockServerHttpRequest authorizationRequest) {
|
|
|
+ return createAuthorizationResponse(authorizationRequest, new LinkedHashMap<>());
|
|
|
+ }
|
|
|
+
|
|
|
+ private static MockServerHttpRequest createAuthorizationResponse(
|
|
|
+ MockServerHttpRequest authorizationRequest, Map<String, String> additionalParameters) {
|
|
|
+ MockServerHttpRequest.BaseBuilder<?> builder = MockServerHttpRequest
|
|
|
+ .get(authorizationRequest.getURI().toString());
|
|
|
+ builder.queryParam(OAuth2ParameterNames.CODE, "code");
|
|
|
+ builder.queryParam(OAuth2ParameterNames.STATE, "state");
|
|
|
+ additionalParameters.forEach(builder::queryParam);
|
|
|
+ builder.cookies(authorizationRequest.getCookies());
|
|
|
+ return builder.build();
|
|
|
+ }
|
|
|
}
|