Browse Source

Polish DefaultOAuth2AuthorizedClientManager

Josh Cummings 5 years ago
parent
commit
bb8706977d

+ 9 - 9
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java

@@ -15,6 +15,13 @@
  */
  */
 package org.springframework.security.oauth2.client.web;
 package org.springframework.security.oauth2.client.web;
 
 
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.function.Function;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+
 import org.springframework.lang.Nullable;
 import org.springframework.lang.Nullable;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.oauth2.client.OAuth2AuthorizationContext;
 import org.springframework.security.oauth2.client.OAuth2AuthorizationContext;
@@ -31,13 +38,6 @@ import org.springframework.util.StringUtils;
 import org.springframework.web.context.request.RequestContextHolder;
 import org.springframework.web.context.request.RequestContextHolder;
 import org.springframework.web.context.request.ServletRequestAttributes;
 import org.springframework.web.context.request.ServletRequestAttributes;
 
 
-import javax.servlet.http.HttpServletRequest;
-import javax.servlet.http.HttpServletResponse;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.function.Function;
-
 /**
 /**
  * The default implementation of an {@link OAuth2AuthorizedClientManager}.
  * The default implementation of an {@link OAuth2AuthorizedClientManager}.
  *
  *
@@ -84,13 +84,13 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori
 		if (authorizedClient != null) {
 		if (authorizedClient != null) {
 			contextBuilder = OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient);
 			contextBuilder = OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient);
 		} else {
 		} else {
-			ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId);
-			Assert.notNull(clientRegistration, "Could not find ClientRegistration with id '" + clientRegistrationId + "'");
 			authorizedClient = this.authorizedClientRepository.loadAuthorizedClient(
 			authorizedClient = this.authorizedClientRepository.loadAuthorizedClient(
 					clientRegistrationId, principal, servletRequest);
 					clientRegistrationId, principal, servletRequest);
 			if (authorizedClient != null) {
 			if (authorizedClient != null) {
 				contextBuilder = OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient);
 				contextBuilder = OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient);
 			} else {
 			} else {
+				ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId);
+				Assert.notNull(clientRegistration, "Could not find ClientRegistration with id '" + clientRegistrationId + "'");
 				contextBuilder = OAuth2AuthorizationContext.withClientRegistration(clientRegistration);
 				contextBuilder = OAuth2AuthorizationContext.withClientRegistration(clientRegistration);
 			}
 			}
 		}
 		}

+ 29 - 17
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java

@@ -15,6 +15,18 @@
  */
  */
 package org.springframework.security.oauth2.client.web.reactive.function.client;
 package org.springframework.security.oauth2.client.web.reactive.function.client;
 
 
+import java.net.URI;
+import java.time.Duration;
+import java.time.Instant;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.function.Consumer;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+
 import org.junit.After;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.Test;
@@ -23,6 +35,8 @@ import org.mockito.ArgumentCaptor;
 import org.mockito.Captor;
 import org.mockito.Captor;
 import org.mockito.Mock;
 import org.mockito.Mock;
 import org.mockito.junit.MockitoJUnitRunner;
 import org.mockito.junit.MockitoJUnitRunner;
+import reactor.util.context.Context;
+
 import org.springframework.core.codec.ByteBufferEncoder;
 import org.springframework.core.codec.ByteBufferEncoder;
 import org.springframework.core.codec.CharSequenceEncoder;
 import org.springframework.core.codec.CharSequenceEncoder;
 import org.springframework.http.HttpHeaders;
 import org.springframework.http.HttpHeaders;
@@ -76,26 +90,26 @@ import org.springframework.web.context.request.ServletRequestAttributes;
 import org.springframework.web.reactive.function.BodyInserter;
 import org.springframework.web.reactive.function.BodyInserter;
 import org.springframework.web.reactive.function.client.ClientRequest;
 import org.springframework.web.reactive.function.client.ClientRequest;
 import org.springframework.web.reactive.function.client.WebClient;
 import org.springframework.web.reactive.function.client.WebClient;
-import reactor.util.context.Context;
-
-import javax.servlet.http.HttpServletRequest;
-import javax.servlet.http.HttpServletResponse;
-import java.net.URI;
-import java.time.Duration;
-import java.time.Instant;
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Optional;
-import java.util.function.Consumer;
 
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
 import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
-import static org.mockito.Mockito.*;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyZeroInteractions;
+import static org.mockito.Mockito.when;
 import static org.springframework.http.HttpMethod.GET;
 import static org.springframework.http.HttpMethod.GET;
 import static org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId;
 import static org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId;
-import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.*;
+import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY;
+import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.authentication;
+import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.getAuthentication;
+import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.getRequest;
+import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.getResponse;
+import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.httpServletRequest;
+import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.httpServletResponse;
+import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient;
 
 
 /**
 /**
  * @author Rob Winch
  * @author Rob Winch
@@ -603,8 +617,6 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
 		when(this.authorizedClientRepository.loadAuthorizedClient(eq(authentication.getAuthorizedClientRegistrationId()),
 		when(this.authorizedClientRepository.loadAuthorizedClient(eq(authentication.getAuthorizedClientRegistrationId()),
 				eq(authentication), eq(servletRequest))).thenReturn(authorizedClient);
 				eq(authentication), eq(servletRequest))).thenReturn(authorizedClient);
 
 
-		when(this.clientRegistrationRepository.findByRegistrationId(eq(authentication.getAuthorizedClientRegistrationId()))).thenReturn(this.registration);
-
 		// Default request attributes set
 		// Default request attributes set
 		final ClientRequest request1 = ClientRequest.create(GET, URI.create("https://example1.com"))
 		final ClientRequest request1 = ClientRequest.create(GET, URI.create("https://example1.com"))
 				.attributes(attrs -> attrs.putAll(getDefaultRequestAttributes())).build();
 				.attributes(attrs -> attrs.putAll(getDefaultRequestAttributes())).build();