浏览代码

oidcLogin Test Configuration Flow

Fixes gh-7794
Josh Cummings 5 年之前
父节点
当前提交
09810b8df9

+ 9 - 9
test/src/main/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurers.java

@@ -471,7 +471,7 @@ public class SecurityMockServerConfigurers {
 		private OAuth2AccessToken accessToken;
 		private OAuth2AccessToken accessToken;
 		private OidcIdToken idToken;
 		private OidcIdToken idToken;
 		private OidcUserInfo userInfo;
 		private OidcUserInfo userInfo;
-		private OidcUser oidcUser;
+		private Supplier<OidcUser> oidcUser = this::defaultPrincipal;
 		private Collection<GrantedAuthority> authorities;
 		private Collection<GrantedAuthority> authorities;
 
 
 		ServerOAuth2AuthorizedClientRepository authorizedClientRepository =
 		ServerOAuth2AuthorizedClientRepository authorizedClientRepository =
@@ -491,6 +491,7 @@ public class SecurityMockServerConfigurers {
 		public OidcLoginMutator authorities(Collection<GrantedAuthority> authorities) {
 		public OidcLoginMutator authorities(Collection<GrantedAuthority> authorities) {
 			Assert.notNull(authorities, "authorities cannot be null");
 			Assert.notNull(authorities, "authorities cannot be null");
 			this.authorities = authorities;
 			this.authorities = authorities;
+			this.oidcUser = this::defaultPrincipal;
 			return this;
 			return this;
 		}
 		}
 
 
@@ -503,6 +504,7 @@ public class SecurityMockServerConfigurers {
 		public OidcLoginMutator authorities(GrantedAuthority... authorities) {
 		public OidcLoginMutator authorities(GrantedAuthority... authorities) {
 			Assert.notNull(authorities, "authorities cannot be null");
 			Assert.notNull(authorities, "authorities cannot be null");
 			this.authorities = Arrays.asList(authorities);
 			this.authorities = Arrays.asList(authorities);
+			this.oidcUser = this::defaultPrincipal;
 			return this;
 			return this;
 		}
 		}
 
 
@@ -517,6 +519,7 @@ public class SecurityMockServerConfigurers {
 			builder.subject("test-subject");
 			builder.subject("test-subject");
 			idTokenBuilderConsumer.accept(builder);
 			idTokenBuilderConsumer.accept(builder);
 			this.idToken = builder.build();
 			this.idToken = builder.build();
+			this.oidcUser = this::defaultPrincipal;
 			return this;
 			return this;
 		}
 		}
 
 
@@ -530,6 +533,7 @@ public class SecurityMockServerConfigurers {
 			OidcUserInfo.Builder builder = OidcUserInfo.builder();
 			OidcUserInfo.Builder builder = OidcUserInfo.builder();
 			userInfoBuilderConsumer.accept(builder);
 			userInfoBuilderConsumer.accept(builder);
 			this.userInfo = builder.build();
 			this.userInfo = builder.build();
+			this.oidcUser = this::defaultPrincipal;
 			return this;
 			return this;
 		}
 		}
 
 
@@ -543,7 +547,7 @@ public class SecurityMockServerConfigurers {
 		 * @return the {@link OidcLoginMutator} for further configuration
 		 * @return the {@link OidcLoginMutator} for further configuration
 		 */
 		 */
 		public OidcLoginMutator oidcUser(OidcUser oidcUser) {
 		public OidcLoginMutator oidcUser(OidcUser oidcUser) {
-			this.oidcUser = oidcUser;
+			this.oidcUser = () -> oidcUser;
 			return this;
 			return this;
 		}
 		}
 
 
@@ -601,7 +605,7 @@ public class SecurityMockServerConfigurers {
 		}
 		}
 
 
 		private OAuth2AuthenticationToken getToken() {
 		private OAuth2AuthenticationToken getToken() {
-			OidcUser oidcUser = getOidcUser();
+			OidcUser oidcUser = this.oidcUser.get();
 			return new OAuth2AuthenticationToken(oidcUser, oidcUser.getAuthorities(), this.clientRegistration.getRegistrationId());
 			return new OAuth2AuthenticationToken(oidcUser, oidcUser.getAuthorities(), this.clientRegistration.getRegistrationId());
 		}
 		}
 
 
@@ -634,12 +638,8 @@ public class SecurityMockServerConfigurers {
 			return this.userInfo;
 			return this.userInfo;
 		}
 		}
 
 
-		private OidcUser getOidcUser() {
-			if (this.oidcUser == null) {
-				return new DefaultOidcUser(getAuthorities(), getOidcIdToken(), this.userInfo);
-			} else {
-				return this.oidcUser;
-			}
+		private OidcUser defaultPrincipal() {
+			return new DefaultOidcUser(getAuthorities(), getOidcIdToken(), this.userInfo);
 		}
 		}
 	}
 	}
 }
 }

+ 11 - 12
test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java

@@ -1432,7 +1432,7 @@ public final class SecurityMockMvcRequestPostProcessors {
 		private OAuth2AccessToken accessToken;
 		private OAuth2AccessToken accessToken;
 		private OidcIdToken idToken;
 		private OidcIdToken idToken;
 		private OidcUserInfo userInfo;
 		private OidcUserInfo userInfo;
-		private OidcUser oidcUser;
+		private Supplier<OidcUser> oidcUser = this::defaultPrincipal;
 		private Collection<GrantedAuthority> authorities;
 		private Collection<GrantedAuthority> authorities;
 
 
 		private OidcLoginRequestPostProcessor(OAuth2AccessToken accessToken) {
 		private OidcLoginRequestPostProcessor(OAuth2AccessToken accessToken) {
@@ -1449,6 +1449,7 @@ public final class SecurityMockMvcRequestPostProcessors {
 		public OidcLoginRequestPostProcessor authorities(Collection<GrantedAuthority> authorities) {
 		public OidcLoginRequestPostProcessor authorities(Collection<GrantedAuthority> authorities) {
 			Assert.notNull(authorities, "authorities cannot be null");
 			Assert.notNull(authorities, "authorities cannot be null");
 			this.authorities = authorities;
 			this.authorities = authorities;
+			this.oidcUser = this::defaultPrincipal;
 			return this;
 			return this;
 		}
 		}
 
 
@@ -1461,6 +1462,7 @@ public final class SecurityMockMvcRequestPostProcessors {
 		public OidcLoginRequestPostProcessor authorities(GrantedAuthority... authorities) {
 		public OidcLoginRequestPostProcessor authorities(GrantedAuthority... authorities) {
 			Assert.notNull(authorities, "authorities cannot be null");
 			Assert.notNull(authorities, "authorities cannot be null");
 			this.authorities = Arrays.asList(authorities);
 			this.authorities = Arrays.asList(authorities);
+			this.oidcUser = this::defaultPrincipal;
 			return this;
 			return this;
 		}
 		}
 
 
@@ -1475,6 +1477,7 @@ public final class SecurityMockMvcRequestPostProcessors {
 			builder.subject("test-subject");
 			builder.subject("test-subject");
 			idTokenBuilderConsumer.accept(builder);
 			idTokenBuilderConsumer.accept(builder);
 			this.idToken = builder.build();
 			this.idToken = builder.build();
+			this.oidcUser = this::defaultPrincipal;
 			return this;
 			return this;
 		}
 		}
 
 
@@ -1488,20 +1491,19 @@ public final class SecurityMockMvcRequestPostProcessors {
 			OidcUserInfo.Builder builder = OidcUserInfo.builder();
 			OidcUserInfo.Builder builder = OidcUserInfo.builder();
 			userInfoBuilderConsumer.accept(builder);
 			userInfoBuilderConsumer.accept(builder);
 			this.userInfo = builder.build();
 			this.userInfo = builder.build();
+			this.oidcUser = this::defaultPrincipal;
 			return this;
 			return this;
 		}
 		}
 
 
 		/**
 		/**
 		 * Use the provided {@link OidcUser} as the authenticated user.
 		 * Use the provided {@link OidcUser} as the authenticated user.
 		 *
 		 *
-		 * Supplying an {@link OidcUser} will take precedence over {@link #idToken}, {@link #userInfo},
-		 * and list of {@link GrantedAuthority}s to use.
 		 *
 		 *
 		 * @param oidcUser the {@link OidcUser} to use
 		 * @param oidcUser the {@link OidcUser} to use
 		 * @return the {@link OidcLoginRequestPostProcessor} for further configuration
 		 * @return the {@link OidcLoginRequestPostProcessor} for further configuration
 		 */
 		 */
 		public OidcLoginRequestPostProcessor oidcUser(OidcUser oidcUser) {
 		public OidcLoginRequestPostProcessor oidcUser(OidcUser oidcUser) {
-			this.oidcUser = oidcUser;
+			this.oidcUser = () -> oidcUser;
 			return this;
 			return this;
 		}
 		}
 
 
@@ -1524,7 +1526,7 @@ public final class SecurityMockMvcRequestPostProcessors {
 
 
 		@Override
 		@Override
 		public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) {
 		public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) {
-			OidcUser oidcUser = getOidcUser();
+			OidcUser oidcUser = this.oidcUser.get();
 			return new OAuth2LoginRequestPostProcessor(this.accessToken)
 			return new OAuth2LoginRequestPostProcessor(this.accessToken)
 					.oauth2User(oidcUser)
 					.oauth2User(oidcUser)
 					.clientRegistration(this.clientRegistration)
 					.clientRegistration(this.clientRegistration)
@@ -1553,7 +1555,8 @@ public final class SecurityMockMvcRequestPostProcessors {
 
 
 		private OidcIdToken getOidcIdToken() {
 		private OidcIdToken getOidcIdToken() {
 			if (this.idToken == null) {
 			if (this.idToken == null) {
-				return new OidcIdToken("id-token", null, null, Collections.singletonMap(IdTokenClaimNames.SUB, "test-subject"));
+				return new OidcIdToken("id-token", null, null,
+						Collections.singletonMap(IdTokenClaimNames.SUB, "test-subject"));
 			} else {
 			} else {
 				return this.idToken;
 				return this.idToken;
 			}
 			}
@@ -1563,12 +1566,8 @@ public final class SecurityMockMvcRequestPostProcessors {
 			return this.userInfo;
 			return this.userInfo;
 		}
 		}
 
 
-		private OidcUser getOidcUser() {
-			if (this.oidcUser == null) {
-				return new DefaultOidcUser(getAuthorities(), getOidcIdToken(), this.userInfo);
-			} else {
-				return this.oidcUser;
-			}
+		private OidcUser defaultPrincipal() {
+			return new DefaultOidcUser(getAuthorities(), getOidcIdToken(), this.userInfo);
 		}
 		}
 	}
 	}
 
 

+ 32 - 0
test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersOidcLoginTests.java

@@ -27,6 +27,7 @@ import org.mockito.junit.MockitoJUnitRunner;
 import org.springframework.http.HttpHeaders;
 import org.springframework.http.HttpHeaders;
 import org.springframework.http.MediaType;
 import org.springframework.http.MediaType;
 import org.springframework.security.core.GrantedAuthority;
 import org.springframework.security.core.GrantedAuthority;
+import org.springframework.security.core.authority.AuthorityUtils;
 import org.springframework.security.core.authority.SimpleGrantedAuthority;
 import org.springframework.security.core.authority.SimpleGrantedAuthority;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
@@ -35,6 +36,7 @@ import org.springframework.security.oauth2.client.registration.ReactiveClientReg
 import org.springframework.security.oauth2.client.web.reactive.result.method.annotation.OAuth2AuthorizedClientArgumentResolver;
 import org.springframework.security.oauth2.client.web.reactive.result.method.annotation.OAuth2AuthorizedClientArgumentResolver;
 import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.client.web.server.WebSessionServerOAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.client.web.server.WebSessionServerOAuth2AuthorizedClientRepository;
+import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
 import org.springframework.security.oauth2.core.oidc.user.OidcUser;
 import org.springframework.security.oauth2.core.oidc.user.OidcUser;
 import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter;
 import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter;
 import org.springframework.test.web.reactive.server.WebTestClient;
 import org.springframework.test.web.reactive.server.WebTestClient;
@@ -42,6 +44,7 @@ import org.springframework.web.bind.annotation.GetMapping;
 import org.springframework.web.bind.annotation.RestController;
 import org.springframework.web.bind.annotation.RestController;
 
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.springframework.security.oauth2.core.oidc.TestOidcIdTokens.idToken;
 import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.mockOidcLogin;
 import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.mockOidcLogin;
 import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.springSecurity;
 import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.springSecurity;
 
 
@@ -143,6 +146,35 @@ public class SecurityMockServerConfigurersOidcLoginTests extends AbstractMockSer
 				.containsEntry("email", "email@email");
 				.containsEntry("email", "email@email");
 	}
 	}
 
 
+	// gh-7794
+	@Test
+	public void oidcLoginWhenOidcUserSpecifiedThenLastCalledTakesPrecedence() throws Exception {
+		OidcUser oidcUser = new DefaultOidcUser(
+				AuthorityUtils.createAuthorityList("SCOPE_user"), idToken().build());
+
+		this.client.mutateWith(mockOidcLogin()
+				.idToken(i -> i.subject("foo"))
+				.oidcUser(oidcUser))
+				.get().uri("/token")
+				.exchange()
+				.expectStatus().isOk();
+
+		OAuth2AuthenticationToken token = this.controller.token;
+		assertThat(token.getPrincipal().getAttributes())
+				.containsEntry("sub", "subject");
+
+		this.client.mutateWith(mockOidcLogin()
+				.oidcUser(oidcUser)
+				.idToken(i -> i.subject("bar")))
+				.get().uri("/token")
+				.exchange()
+				.expectStatus().isOk();
+
+		token = this.controller.token;
+		assertThat(token.getPrincipal().getAttributes())
+				.containsEntry("sub", "bar");
+	}
+
 	@RestController
 	@RestController
 	static class OAuth2LoginController {
 	static class OAuth2LoginController {
 		volatile OAuth2AuthenticationToken token;
 		volatile OAuth2AuthenticationToken token;

+ 22 - 0
test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsOidcLoginTests.java

@@ -31,12 +31,14 @@ import org.springframework.security.config.annotation.web.configuration.EnableWe
 import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
 import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
 import org.springframework.security.core.GrantedAuthority;
 import org.springframework.security.core.GrantedAuthority;
 import org.springframework.security.core.annotation.AuthenticationPrincipal;
 import org.springframework.security.core.annotation.AuthenticationPrincipal;
+import org.springframework.security.core.authority.AuthorityUtils;
 import org.springframework.security.core.authority.SimpleGrantedAuthority;
 import org.springframework.security.core.authority.SimpleGrantedAuthority;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
 import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
+import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
 import org.springframework.security.oauth2.core.oidc.user.OidcUser;
 import org.springframework.security.oauth2.core.oidc.user.OidcUser;
 import org.springframework.security.test.context.TestSecurityContextHolder;
 import org.springframework.security.test.context.TestSecurityContextHolder;
 import org.springframework.test.context.ContextConfiguration;
 import org.springframework.test.context.ContextConfiguration;
@@ -51,6 +53,7 @@ import org.springframework.web.context.WebApplicationContext;
 import org.springframework.web.servlet.config.annotation.EnableWebMvc;
 import org.springframework.web.servlet.config.annotation.EnableWebMvc;
 
 
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.mock;
+import static org.springframework.security.oauth2.core.oidc.TestOidcIdTokens.idToken;
 import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.oidcLogin;
 import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.oidcLogin;
 import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity;
 import static org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers.springSecurity;
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
@@ -126,6 +129,25 @@ public class SecurityMockMvcRequestPostProcessorsOidcLoginTests {
 				.andExpect(content().string("email@email"));
 				.andExpect(content().string("email@email"));
 	}
 	}
 
 
+	// gh-7794
+	@Test
+	public void oidcLoginWhenOidcUserSpecifiedThenLastCalledTakesPrecedence() throws Exception {
+		OidcUser oidcUser = new DefaultOidcUser(
+				AuthorityUtils.createAuthorityList("SCOPE_user"), idToken().build());
+
+		this.mvc.perform(get("/id-token/sub")
+				.with(oidcLogin()
+						.idToken(i -> i.subject("foo"))
+						.oidcUser(oidcUser)))
+				.andExpect(status().isOk())
+				.andExpect(content().string("subject"));
+		this.mvc.perform(get("/id-token/sub")
+				.with(oidcLogin()
+						.oidcUser(oidcUser)
+						.idToken(i -> i.subject("bar"))))
+				.andExpect(content().string("bar"));
+	}
+
 	@EnableWebSecurity
 	@EnableWebSecurity
 	@EnableWebMvc
 	@EnableWebMvc
 	static class OAuth2LoginConfig extends WebSecurityConfigurerAdapter {
 	static class OAuth2LoginConfig extends WebSecurityConfigurerAdapter {