|
@@ -15,6 +15,18 @@
|
|
|
*/
|
|
|
package org.springframework.security.oauth2.client.web.reactive.function.client;
|
|
|
|
|
|
+import java.net.URI;
|
|
|
+import java.time.Duration;
|
|
|
+import java.time.Instant;
|
|
|
+import java.util.ArrayList;
|
|
|
+import java.util.HashMap;
|
|
|
+import java.util.List;
|
|
|
+import java.util.Map;
|
|
|
+import java.util.Optional;
|
|
|
+import java.util.function.Consumer;
|
|
|
+import javax.servlet.http.HttpServletRequest;
|
|
|
+import javax.servlet.http.HttpServletResponse;
|
|
|
+
|
|
|
import org.junit.After;
|
|
|
import org.junit.Before;
|
|
|
import org.junit.Test;
|
|
@@ -23,6 +35,8 @@ import org.mockito.ArgumentCaptor;
|
|
|
import org.mockito.Captor;
|
|
|
import org.mockito.Mock;
|
|
|
import org.mockito.junit.MockitoJUnitRunner;
|
|
|
+import reactor.util.context.Context;
|
|
|
+
|
|
|
import org.springframework.core.codec.ByteBufferEncoder;
|
|
|
import org.springframework.core.codec.CharSequenceEncoder;
|
|
|
import org.springframework.http.HttpHeaders;
|
|
@@ -76,26 +90,26 @@ import org.springframework.web.context.request.ServletRequestAttributes;
|
|
|
import org.springframework.web.reactive.function.BodyInserter;
|
|
|
import org.springframework.web.reactive.function.client.ClientRequest;
|
|
|
import org.springframework.web.reactive.function.client.WebClient;
|
|
|
-import reactor.util.context.Context;
|
|
|
-
|
|
|
-import javax.servlet.http.HttpServletRequest;
|
|
|
-import javax.servlet.http.HttpServletResponse;
|
|
|
-import java.net.URI;
|
|
|
-import java.time.Duration;
|
|
|
-import java.time.Instant;
|
|
|
-import java.util.ArrayList;
|
|
|
-import java.util.HashMap;
|
|
|
-import java.util.List;
|
|
|
-import java.util.Map;
|
|
|
-import java.util.Optional;
|
|
|
-import java.util.function.Consumer;
|
|
|
|
|
|
import static org.assertj.core.api.Assertions.assertThat;
|
|
|
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
|
|
|
-import static org.mockito.Mockito.*;
|
|
|
+import static org.mockito.Mockito.any;
|
|
|
+import static org.mockito.Mockito.eq;
|
|
|
+import static org.mockito.Mockito.mock;
|
|
|
+import static org.mockito.Mockito.never;
|
|
|
+import static org.mockito.Mockito.verify;
|
|
|
+import static org.mockito.Mockito.verifyZeroInteractions;
|
|
|
+import static org.mockito.Mockito.when;
|
|
|
import static org.springframework.http.HttpMethod.GET;
|
|
|
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId;
|
|
|
-import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.*;
|
|
|
+import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY;
|
|
|
+import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.authentication;
|
|
|
+import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.getAuthentication;
|
|
|
+import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.getRequest;
|
|
|
+import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.getResponse;
|
|
|
+import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.httpServletRequest;
|
|
|
+import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.httpServletResponse;
|
|
|
+import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient;
|
|
|
|
|
|
/**
|
|
|
* @author Rob Winch
|
|
@@ -603,8 +617,6 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
|
|
when(this.authorizedClientRepository.loadAuthorizedClient(eq(authentication.getAuthorizedClientRegistrationId()),
|
|
|
eq(authentication), eq(servletRequest))).thenReturn(authorizedClient);
|
|
|
|
|
|
- when(this.clientRegistrationRepository.findByRegistrationId(eq(authentication.getAuthorizedClientRegistrationId()))).thenReturn(this.registration);
|
|
|
-
|
|
|
// Default request attributes set
|
|
|
final ClientRequest request1 = ClientRequest.create(GET, URI.create("https://example1.com"))
|
|
|
.attributes(attrs -> attrs.putAll(getDefaultRequestAttributes())).build();
|