2
0
Эх сурвалжийг харах

Replace ExpectedException @Rules with AssertJ

Replace JUnit ExpectedException @Rules with AssertJ calls.
Phillip Webb 5 жил өмнө
parent
commit
20baa7d409
24 өөрчлөгдсөн 384 нэмэгдсэн , 544 устгасан
  1. 6 11
      config/src/test/java/org/springframework/security/config/SecurityNamespaceHandlerTests.java
  2. 2 6
      config/src/test/java/org/springframework/security/config/annotation/method/configuration/GlobalMethodSecurityConfigurationTests.java
  3. 9 17
      crypto/src/test/java/org/springframework/security/crypto/codec/HexTests.java
  4. 5 9
      itest/ldap/embedded-ldap-none/src/integration-test/java/org/springframework/security/LdapServerBeanDefinitionParserTests.java
  5. 4 32
      ldap/src/test/java/org/springframework/security/ldap/authentication/ad/ActiveDirectoryLdapAuthenticationProviderTests.java
  6. 15 20
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProviderTests.java
  7. 27 31
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusAuthorizationCodeTokenResponseClientTests.java
  8. 32 38
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProviderTests.java
  9. 25 26
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java
  10. 23 28
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/CustomUserTypesOAuth2UserServiceTests.java
  11. 46 46
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java
  12. 27 37
      saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/Saml2X509CredentialTests.java
  13. 19 21
      saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/credentials/Saml2X509CredentialTests.java
  14. 43 60
      saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java
  15. 2 9
      saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java
  16. 3 8
      saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilterTests.java
  17. 10 14
      web/src/test/java/org/springframework/security/web/authentication/DelegatingAuthenticationFailureHandlerTests.java
  18. 2 8
      web/src/test/java/org/springframework/security/web/authentication/logout/CompositeLogoutHandlerTests.java
  19. 5 11
      web/src/test/java/org/springframework/security/web/authentication/logout/ForwardLogoutSuccessHandlerTests.java
  20. 3 8
      web/src/test/java/org/springframework/security/web/authentication/logout/HeaderWriterLogoutHandlerTests.java
  21. 3 8
      web/src/test/java/org/springframework/security/web/authentication/switchuser/SwitchUserFilterTests.java
  22. 19 25
      web/src/test/java/org/springframework/security/web/firewall/FirewalledResponseTests.java
  23. 3 8
      web/src/test/java/org/springframework/security/web/header/writers/ClearSiteDataHeaderWriterTests.java
  24. 51 63
      web/src/test/java/org/springframework/security/web/server/authentication/SwitchUserWebFilterTests.java

+ 6 - 11
config/src/test/java/org/springframework/security/config/SecurityNamespaceHandlerTests.java

@@ -17,9 +17,7 @@
 package org.springframework.security.config;
 
 import org.apache.commons.logging.Log;
-import org.junit.Rule;
 import org.junit.Test;
-import org.junit.rules.ExpectedException;
 import org.junit.runner.RunWith;
 import org.powermock.api.mockito.PowerMockito;
 import org.powermock.core.classloader.annotations.PowerMockIgnore;
@@ -48,9 +46,6 @@ import static org.mockito.Mockito.verifyZeroInteractions;
 @PowerMockIgnore({ "org.w3c.dom.*", "org.xml.sax.*", "org.apache.xerces.*", "javax.xml.parsers.*" })
 public class SecurityNamespaceHandlerTests {
 
-	@Rule
-	public ExpectedException thrown = ExpectedException.none();
-
 	// @formatter:off
 	private static final String XML_AUTHENTICATION_MANAGER = "<authentication-manager>"
 			+ "  <authentication-provider>"
@@ -103,12 +98,12 @@ public class SecurityNamespaceHandlerTests {
 	@Test
 	public void filterNoClassDefFoundError() throws Exception {
 		String className = "javax.servlet.Filter";
-		this.thrown.expect(BeanDefinitionParsingException.class);
-		this.thrown.expectMessage("NoClassDefFoundError: " + className);
 		PowerMockito.spy(ClassUtils.class);
 		PowerMockito.doThrow(new NoClassDefFoundError(className)).when(ClassUtils.class, "forName",
 				eq(FILTER_CHAIN_PROXY_CLASSNAME), any(ClassLoader.class));
-		new InMemoryXmlApplicationContext(XML_AUTHENTICATION_MANAGER + XML_HTTP_BLOCK);
+		assertThatExceptionOfType(BeanDefinitionParsingException.class)
+				.isThrownBy(() -> new InMemoryXmlApplicationContext(XML_AUTHENTICATION_MANAGER + XML_HTTP_BLOCK))
+				.withMessageContaining("NoClassDefFoundError: " + className);
 	}
 
 	@Test
@@ -124,12 +119,12 @@ public class SecurityNamespaceHandlerTests {
 	@Test
 	public void filterChainProxyClassNotFoundException() throws Exception {
 		String className = FILTER_CHAIN_PROXY_CLASSNAME;
-		this.thrown.expect(BeanDefinitionParsingException.class);
-		this.thrown.expectMessage("ClassNotFoundException: " + className);
 		PowerMockito.spy(ClassUtils.class);
 		PowerMockito.doThrow(new ClassNotFoundException(className)).when(ClassUtils.class, "forName",
 				eq(FILTER_CHAIN_PROXY_CLASSNAME), any(ClassLoader.class));
-		new InMemoryXmlApplicationContext(XML_AUTHENTICATION_MANAGER + XML_HTTP_BLOCK);
+		assertThatExceptionOfType(BeanDefinitionParsingException.class)
+				.isThrownBy(() -> new InMemoryXmlApplicationContext(XML_AUTHENTICATION_MANAGER + XML_HTTP_BLOCK))
+				.withMessageContaining("ClassNotFoundException: " + className);
 	}
 
 	@Test

+ 2 - 6
config/src/test/java/org/springframework/security/config/annotation/method/configuration/GlobalMethodSecurityConfigurationTests.java

@@ -25,7 +25,6 @@ import javax.sql.DataSource;
 import org.aopalliance.intercept.MethodInterceptor;
 import org.junit.Rule;
 import org.junit.Test;
-import org.junit.rules.ExpectedException;
 import org.junit.runner.RunWith;
 
 import org.springframework.beans.BeansException;
@@ -80,9 +79,6 @@ public class GlobalMethodSecurityConfigurationTests {
 	@Rule
 	public final SpringTestRule spring = new SpringTestRule();
 
-	@Rule
-	public ExpectedException thrown = ExpectedException.none();
-
 	@Autowired(required = false)
 	private MethodSecurityService service;
 
@@ -98,8 +94,8 @@ public class GlobalMethodSecurityConfigurationTests {
 
 	@Test
 	public void configureWhenGlobalMethodSecurityIsMissingMetadataSourceThenException() {
-		this.thrown.expect(UnsatisfiedDependencyException.class);
-		this.spring.register(IllegalStateGlobalMethodSecurityConfig.class).autowire();
+		assertThatExceptionOfType(UnsatisfiedDependencyException.class)
+				.isThrownBy(() -> this.spring.register(IllegalStateGlobalMethodSecurityConfig.class).autowire());
 	}
 
 	@Test

+ 9 - 17
crypto/src/test/java/org/springframework/security/crypto/codec/HexTests.java

@@ -16,11 +16,10 @@
 
 package org.springframework.security.crypto.codec;
 
-import org.junit.Rule;
 import org.junit.Test;
-import org.junit.rules.ExpectedException;
 
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 
 /**
  * Test cases for {@link Hex}.
@@ -29,9 +28,6 @@ import static org.assertj.core.api.Assertions.assertThat;
  */
 public class HexTests {
 
-	@Rule
-	public ExpectedException expectedException = ExpectedException.none();
-
 	@Test
 	public void encode() {
 		assertThat(Hex.encode(new byte[] { (byte) 'A', (byte) 'B', (byte) 'C', (byte) 'D' }))
@@ -55,30 +51,26 @@ public class HexTests {
 
 	@Test
 	public void decodeNotEven() {
-		this.expectedException.expect(IllegalArgumentException.class);
-		this.expectedException.expectMessage("Hex-encoded string must have an even number of characters");
-		Hex.decode("414243444");
+		assertThatIllegalArgumentException().isThrownBy(() -> Hex.decode("414243444"))
+				.withMessage("Hex-encoded string must have an even number of characters");
 	}
 
 	@Test
 	public void decodeExistNonHexCharAtFirst() {
-		this.expectedException.expect(IllegalArgumentException.class);
-		this.expectedException.expectMessage("Detected a Non-hex character at 1 or 2 position");
-		Hex.decode("G0");
+		assertThatIllegalArgumentException().isThrownBy(() -> Hex.decode("G0"))
+				.withMessage("Detected a Non-hex character at 1 or 2 position");
 	}
 
 	@Test
 	public void decodeExistNonHexCharAtSecond() {
-		this.expectedException.expect(IllegalArgumentException.class);
-		this.expectedException.expectMessage("Detected a Non-hex character at 3 or 4 position");
-		Hex.decode("410G");
+		assertThatIllegalArgumentException().isThrownBy(() -> Hex.decode("410G"))
+				.withMessage("Detected a Non-hex character at 3 or 4 position");
 	}
 
 	@Test
 	public void decodeExistNonHexCharAtBoth() {
-		this.expectedException.expect(IllegalArgumentException.class);
-		this.expectedException.expectMessage("Detected a Non-hex character at 5 or 6 position");
-		Hex.decode("4142GG");
+		assertThatIllegalArgumentException().isThrownBy(() -> Hex.decode("4142GG"))
+				.withMessage("Detected a Non-hex character at 5 or 6 position");
 	}
 
 }

+ 5 - 9
itest/ldap/embedded-ldap-none/src/integration-test/java/org/springframework/security/LdapServerBeanDefinitionParserTests.java

@@ -17,21 +17,18 @@
 package org.springframework.security;
 
 import org.junit.After;
-import org.junit.Rule;
 import org.junit.Test;
-import org.junit.rules.ExpectedException;
 
 import org.springframework.beans.factory.BeanDefinitionStoreException;
 import org.springframework.context.support.ClassPathXmlApplicationContext;
 
+import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
+
 /**
  * @author Eddú Meléndez
  */
 public class LdapServerBeanDefinitionParserTests {
 
-	@Rule
-	public ExpectedException thrown = ExpectedException.none();
-
 	private ClassPathXmlApplicationContext context;
 
 	@After
@@ -44,10 +41,9 @@ public class LdapServerBeanDefinitionParserTests {
 
 	@Test
 	public void apacheDirectoryServerIsStartedByDefault() {
-		this.thrown.expect(BeanDefinitionStoreException.class);
-		this.thrown.expectMessage("Embedded LDAP server is not provided");
-
-		this.context = new ClassPathXmlApplicationContext("applicationContext-security.xml");
+		assertThatExceptionOfType(BeanDefinitionStoreException.class)
+				.isThrownBy(() -> this.context = new ClassPathXmlApplicationContext("applicationContext-security.xml"))
+				.withMessageContaining("Embedded LDAP server is not provided");
 	}
 
 }

+ 4 - 32
ldap/src/test/java/org/springframework/security/ldap/authentication/ad/ActiveDirectoryLdapAuthenticationProviderTests.java

@@ -30,14 +30,8 @@ import javax.naming.directory.SearchControls;
 import javax.naming.directory.SearchResult;
 
 import org.apache.directory.shared.ldap.util.EmptyEnumeration;
-import org.hamcrest.BaseMatcher;
-import org.hamcrest.CoreMatchers;
-import org.hamcrest.Description;
-import org.hamcrest.Matcher;
 import org.junit.Before;
-import org.junit.Rule;
 import org.junit.Test;
-import org.junit.rules.ExpectedException;
 import org.mockito.ArgumentCaptor;
 
 import org.springframework.dao.IncorrectResultSizeDataAccessException;
@@ -71,9 +65,6 @@ public class ActiveDirectoryLdapAuthenticationProviderTests {
 
 	public static final String NON_EXISTING_LDAP_PROVIDER = "ldap://192.168.1.201/";
 
-	@Rule
-	public ExpectedException thrown = ExpectedException.none();
-
 	ActiveDirectoryLdapAuthenticationProvider provider;
 
 	UsernamePasswordAuthenticationToken joe = new UsernamePasswordAuthenticationToken("joe", "password");
@@ -245,29 +236,10 @@ public class ActiveDirectoryLdapAuthenticationProviderTests {
 		this.provider.contextFactory = createContextFactoryThrowing(
 				new AuthenticationException(msg + dataCode + ", xxxx]"));
 		this.provider.setConvertSubErrorCodesToExceptions(true);
-		this.thrown.expect(BadCredentialsException.class);
-		this.thrown.expect(new BaseMatcher<BadCredentialsException>() {
-			private Matcher<Object> causeInstance = CoreMatchers
-					.instanceOf(ActiveDirectoryAuthenticationException.class);
-
-			private Matcher<String> causeDataCode = CoreMatchers.equalTo(dataCode);
-
-			@Override
-			public boolean matches(Object that) {
-				Throwable t = (Throwable) that;
-				ActiveDirectoryAuthenticationException cause = (ActiveDirectoryAuthenticationException) t.getCause();
-				return this.causeInstance.matches(cause) && this.causeDataCode.matches(cause.getDataCode());
-			}
-
-			@Override
-			public void describeTo(Description desc) {
-				desc.appendText("getCause() ");
-				this.causeInstance.describeTo(desc);
-				desc.appendText("getCause().getDataCode() ");
-				this.causeDataCode.describeTo(desc);
-			}
-		});
-		this.provider.authenticate(this.joe);
+		assertThatExceptionOfType(BadCredentialsException.class).isThrownBy(() -> this.provider.authenticate(this.joe))
+				.withCauseInstanceOf(ActiveDirectoryAuthenticationException.class)
+				.satisfies((ex) -> assertThat(((ActiveDirectoryAuthenticationException) ex.getCause()).getDataCode())
+						.isEqualTo(dataCode));
 	}
 
 	@Test(expected = CredentialsExpiredException.class)

+ 15 - 20
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProviderTests.java

@@ -25,9 +25,7 @@ import java.util.Map;
 import java.util.Set;
 
 import org.junit.Before;
-import org.junit.Rule;
 import org.junit.Test;
-import org.junit.rules.ExpectedException;
 import org.mockito.ArgumentCaptor;
 import org.mockito.stubbing.Answer;
 
@@ -52,7 +50,8 @@ import org.springframework.security.oauth2.core.endpoint.TestOAuth2Authorization
 import org.springframework.security.oauth2.core.user.OAuth2User;
 
 import static org.assertj.core.api.Assertions.assertThat;
-import static org.hamcrest.CoreMatchers.containsString;
+import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
+import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyCollection;
 import static org.mockito.BDDMockito.given;
@@ -79,9 +78,6 @@ public class OAuth2LoginAuthenticationProviderTests {
 
 	private OAuth2LoginAuthenticationProvider authenticationProvider;
 
-	@Rule
-	public ExpectedException exception = ExpectedException.none();
-
 	@Before
 	@SuppressWarnings("unchecked")
 	public void setUp() {
@@ -98,20 +94,19 @@ public class OAuth2LoginAuthenticationProviderTests {
 
 	@Test
 	public void constructorWhenAccessTokenResponseClientIsNullThenThrowIllegalArgumentException() {
-		this.exception.expect(IllegalArgumentException.class);
-		new OAuth2LoginAuthenticationProvider(null, this.userService);
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> new OAuth2LoginAuthenticationProvider(null, this.userService));
 	}
 
 	@Test
 	public void constructorWhenUserServiceIsNullThenThrowIllegalArgumentException() {
-		this.exception.expect(IllegalArgumentException.class);
-		new OAuth2LoginAuthenticationProvider(this.accessTokenResponseClient, null);
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> new OAuth2LoginAuthenticationProvider(this.accessTokenResponseClient, null));
 	}
 
 	@Test
 	public void setAuthoritiesMapperWhenAuthoritiesMapperIsNullThenThrowIllegalArgumentException() {
-		this.exception.expect(IllegalArgumentException.class);
-		this.authenticationProvider.setAuthoritiesMapper(null);
+		assertThatIllegalArgumentException().isThrownBy(() -> this.authenticationProvider.setAuthoritiesMapper(null));
 	}
 
 	@Test
@@ -132,26 +127,26 @@ public class OAuth2LoginAuthenticationProviderTests {
 
 	@Test
 	public void authenticateWhenAuthorizationErrorResponseThenThrowOAuth2AuthenticationException() {
-		this.exception.expect(OAuth2AuthenticationException.class);
-		this.exception.expectMessage(containsString(OAuth2ErrorCodes.INVALID_REQUEST));
 		OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.error()
 				.errorCode(OAuth2ErrorCodes.INVALID_REQUEST).build();
 		OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest,
 				authorizationResponse);
-		this.authenticationProvider
-				.authenticate(new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange));
+		assertThatExceptionOfType(OAuth2AuthenticationException.class)
+				.isThrownBy(() -> this.authenticationProvider.authenticate(
+						new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange)))
+				.withMessageContaining(OAuth2ErrorCodes.INVALID_REQUEST);
 	}
 
 	@Test
 	public void authenticateWhenAuthorizationResponseStateNotEqualAuthorizationRequestStateThenThrowOAuth2AuthenticationException() {
-		this.exception.expect(OAuth2AuthenticationException.class);
-		this.exception.expectMessage(containsString("invalid_state_parameter"));
 		OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.success().state("67890")
 				.build();
 		OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest,
 				authorizationResponse);
-		this.authenticationProvider
-				.authenticate(new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange));
+		assertThatExceptionOfType(OAuth2AuthenticationException.class)
+				.isThrownBy(() -> this.authenticationProvider.authenticate(
+						new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange)))
+				.withMessageContaining("invalid_state_parameter");
 	}
 
 	@Test

+ 27 - 31
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusAuthorizationCodeTokenResponseClientTests.java

@@ -21,9 +21,7 @@ import java.time.Instant;
 import okhttp3.mockwebserver.MockResponse;
 import okhttp3.mockwebserver.MockWebServer;
 import org.junit.Before;
-import org.junit.Rule;
 import org.junit.Test;
-import org.junit.rules.ExpectedException;
 
 import org.springframework.http.HttpHeaders;
 import org.springframework.http.MediaType;
@@ -40,7 +38,8 @@ import org.springframework.security.oauth2.core.endpoint.TestOAuth2Authorization
 import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses;
 
 import static org.assertj.core.api.Assertions.assertThat;
-import static org.hamcrest.CoreMatchers.containsString;
+import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
+import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 
 /**
  * Tests for {@link NimbusAuthorizationCodeTokenResponseClient}.
@@ -59,9 +58,6 @@ public class NimbusAuthorizationCodeTokenResponseClientTests {
 
 	private NimbusAuthorizationCodeTokenResponseClient tokenResponseClient = new NimbusAuthorizationCodeTokenResponseClient();
 
-	@Rule
-	public ExpectedException exception = ExpectedException.none();
-
 	@Before
 	public void setUp() {
 		this.clientRegistrationBuilder = TestClientRegistrations.clientRegistration()
@@ -109,29 +105,27 @@ public class NimbusAuthorizationCodeTokenResponseClientTests {
 
 	@Test
 	public void getTokenResponseWhenRedirectUriMalformedThenThrowIllegalArgumentException() {
-		this.exception.expect(IllegalArgumentException.class);
 		String redirectUri = "http:\\example.com";
 		OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request()
 				.redirectUri(redirectUri).build();
 		OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest,
 				this.authorizationResponse);
-		this.tokenResponseClient.getTokenResponse(
-				new OAuth2AuthorizationCodeGrantRequest(this.clientRegistrationBuilder.build(), authorizationExchange));
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest(
+						this.clientRegistrationBuilder.build(), authorizationExchange)));
 	}
 
 	@Test
 	public void getTokenResponseWhenTokenUriMalformedThenThrowIllegalArgumentException() {
-		this.exception.expect(IllegalArgumentException.class);
 		String tokenUri = "http:\\provider.com\\oauth2\\token";
 		this.clientRegistrationBuilder.tokenUri(tokenUri);
-		this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest(
-				this.clientRegistrationBuilder.build(), this.authorizationExchange));
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest(
+						this.clientRegistrationBuilder.build(), this.authorizationExchange)));
 	}
 
 	@Test
 	public void getTokenResponseWhenSuccessResponseInvalidThenThrowOAuth2AuthorizationException() throws Exception {
-		this.exception.expect(OAuth2AuthorizationException.class);
-		this.exception.expectMessage(containsString("invalid_token_response"));
 		MockWebServer server = new MockWebServer();
 		// @formatter:off
 		String accessTokenSuccessResponse = "{\n"
@@ -149,8 +143,10 @@ public class NimbusAuthorizationCodeTokenResponseClientTests {
 		String tokenUri = server.url("/oauth2/token").toString();
 		this.clientRegistrationBuilder.tokenUri(tokenUri);
 		try {
-			this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest(
-					this.clientRegistrationBuilder.build(), this.authorizationExchange));
+			assertThatExceptionOfType(OAuth2AuthorizationException.class)
+					.isThrownBy(() -> this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest(
+							this.clientRegistrationBuilder.build(), this.authorizationExchange)))
+					.withMessageContaining("invalid_token_response");
 		}
 		finally {
 			server.shutdown();
@@ -159,17 +155,15 @@ public class NimbusAuthorizationCodeTokenResponseClientTests {
 
 	@Test
 	public void getTokenResponseWhenTokenUriInvalidThenThrowOAuth2AuthorizationException() {
-		this.exception.expect(OAuth2AuthorizationException.class);
 		String tokenUri = "https://invalid-provider.com/oauth2/token";
 		this.clientRegistrationBuilder.tokenUri(tokenUri);
-		this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest(
-				this.clientRegistrationBuilder.build(), this.authorizationExchange));
+		assertThatExceptionOfType(OAuth2AuthorizationException.class)
+				.isThrownBy(() -> this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest(
+						this.clientRegistrationBuilder.build(), this.authorizationExchange)));
 	}
 
 	@Test
 	public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() throws Exception {
-		this.exception.expect(OAuth2AuthorizationException.class);
-		this.exception.expectMessage(containsString("unauthorized_client"));
 		MockWebServer server = new MockWebServer();
 		// @formatter:off
 		String accessTokenErrorResponse = "{\n"
@@ -182,8 +176,10 @@ public class NimbusAuthorizationCodeTokenResponseClientTests {
 		String tokenUri = server.url("/oauth2/token").toString();
 		this.clientRegistrationBuilder.tokenUri(tokenUri);
 		try {
-			this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest(
-					this.clientRegistrationBuilder.build(), this.authorizationExchange));
+			assertThatExceptionOfType(OAuth2AuthorizationException.class)
+					.isThrownBy(() -> this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest(
+							this.clientRegistrationBuilder.build(), this.authorizationExchange)))
+					.withMessageContaining("unauthorized_client");
 		}
 		finally {
 			server.shutdown();
@@ -193,16 +189,16 @@ public class NimbusAuthorizationCodeTokenResponseClientTests {
 	// gh-5594
 	@Test
 	public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() throws Exception {
-		this.exception.expect(OAuth2AuthorizationException.class);
-		this.exception.expectMessage(containsString("server_error"));
 		MockWebServer server = new MockWebServer();
 		server.enqueue(new MockResponse().setResponseCode(500));
 		server.start();
 		String tokenUri = server.url("/oauth2/token").toString();
 		this.clientRegistrationBuilder.tokenUri(tokenUri);
 		try {
-			this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest(
-					this.clientRegistrationBuilder.build(), this.authorizationExchange));
+			assertThatExceptionOfType(OAuth2AuthorizationException.class)
+					.isThrownBy(() -> this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest(
+							this.clientRegistrationBuilder.build(), this.authorizationExchange)))
+					.withMessageContaining("server_error");
 		}
 		finally {
 			server.shutdown();
@@ -212,8 +208,6 @@ public class NimbusAuthorizationCodeTokenResponseClientTests {
 	@Test
 	public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException()
 			throws Exception {
-		this.exception.expect(OAuth2AuthorizationException.class);
-		this.exception.expectMessage(containsString("invalid_token_response"));
 		MockWebServer server = new MockWebServer();
 		// @formatter:off
 		String accessTokenSuccessResponse = "{\n"
@@ -228,8 +222,10 @@ public class NimbusAuthorizationCodeTokenResponseClientTests {
 		String tokenUri = server.url("/oauth2/token").toString();
 		this.clientRegistrationBuilder.tokenUri(tokenUri);
 		try {
-			this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest(
-					this.clientRegistrationBuilder.build(), this.authorizationExchange));
+			assertThatExceptionOfType(OAuth2AuthorizationException.class)
+					.isThrownBy(() -> this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest(
+							this.clientRegistrationBuilder.build(), this.authorizationExchange)))
+					.withMessageContaining("invalid_token_response");
 		}
 		finally {
 			server.shutdown();

+ 32 - 38
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProviderTests.java

@@ -28,9 +28,7 @@ import java.util.Map;
 import java.util.Set;
 
 import org.junit.Before;
-import org.junit.Rule;
 import org.junit.Test;
-import org.junit.rules.ExpectedException;
 import org.mockito.ArgumentCaptor;
 import org.mockito.stubbing.Answer;
 
@@ -64,7 +62,8 @@ import org.springframework.security.oauth2.jwt.JwtException;
 import org.springframework.security.oauth2.jwt.TestJwts;
 
 import static org.assertj.core.api.Assertions.assertThat;
-import static org.hamcrest.CoreMatchers.containsString;
+import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
+import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyCollection;
 import static org.mockito.ArgumentMatchers.anyString;
@@ -100,9 +99,6 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
 
 	private String nonceHash;
 
-	@Rule
-	public ExpectedException exception = ExpectedException.none();
-
 	@Before
 	@SuppressWarnings("unchecked")
 	public void setUp() {
@@ -138,26 +134,24 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
 
 	@Test
 	public void constructorWhenAccessTokenResponseClientIsNullThenThrowIllegalArgumentException() {
-		this.exception.expect(IllegalArgumentException.class);
-		new OidcAuthorizationCodeAuthenticationProvider(null, this.userService);
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> new OidcAuthorizationCodeAuthenticationProvider(null, this.userService));
 	}
 
 	@Test
 	public void constructorWhenUserServiceIsNullThenThrowIllegalArgumentException() {
-		this.exception.expect(IllegalArgumentException.class);
-		new OidcAuthorizationCodeAuthenticationProvider(this.accessTokenResponseClient, null);
+		assertThatIllegalArgumentException().isThrownBy(
+				() -> new OidcAuthorizationCodeAuthenticationProvider(this.accessTokenResponseClient, null));
 	}
 
 	@Test
 	public void setJwtDecoderFactoryWhenNullThenThrowIllegalArgumentException() {
-		this.exception.expect(IllegalArgumentException.class);
-		this.authenticationProvider.setJwtDecoderFactory(null);
+		assertThatIllegalArgumentException().isThrownBy(() -> this.authenticationProvider.setJwtDecoderFactory(null));
 	}
 
 	@Test
 	public void setAuthoritiesMapperWhenAuthoritiesMapperIsNullThenThrowIllegalArgumentException() {
-		this.exception.expect(IllegalArgumentException.class);
-		this.authenticationProvider.setAuthoritiesMapper(null);
+		assertThatIllegalArgumentException().isThrownBy(() -> this.authenticationProvider.setAuthoritiesMapper(null));
 	}
 
 	@Test
@@ -181,8 +175,6 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
 
 	@Test
 	public void authenticateWhenAuthorizationErrorResponseThenThrowOAuth2AuthenticationException() {
-		this.exception.expect(OAuth2AuthenticationException.class);
-		this.exception.expectMessage(containsString(OAuth2ErrorCodes.INVALID_SCOPE));
 		// @formatter:off
 		OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.error()
 				.errorCode(OAuth2ErrorCodes.INVALID_SCOPE)
@@ -190,14 +182,14 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
 		// @formatter:on
 		OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest,
 				authorizationResponse);
-		this.authenticationProvider
-				.authenticate(new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange));
+		assertThatExceptionOfType(OAuth2AuthenticationException.class)
+				.isThrownBy(() -> this.authenticationProvider.authenticate(
+						new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange)))
+				.withMessageContaining(OAuth2ErrorCodes.INVALID_SCOPE);
 	}
 
 	@Test
 	public void authenticateWhenAuthorizationResponseStateNotEqualAuthorizationRequestStateThenThrowOAuth2AuthenticationException() {
-		this.exception.expect(OAuth2AuthenticationException.class);
-		this.exception.expectMessage(containsString("invalid_state_parameter"));
 		// @formatter:off
 		OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.success()
 				.state("89012")
@@ -205,14 +197,14 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
 		// @formatter:on
 		OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest,
 				authorizationResponse);
-		this.authenticationProvider
-				.authenticate(new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange));
+		assertThatExceptionOfType(OAuth2AuthenticationException.class)
+				.isThrownBy(() -> this.authenticationProvider.authenticate(
+						new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange)))
+				.withMessageContaining("invalid_state_parameter");
 	}
 
 	@Test
 	public void authenticateWhenTokenResponseDoesNotContainIdTokenThenThrowOAuth2AuthenticationException() {
-		this.exception.expect(OAuth2AuthenticationException.class);
-		this.exception.expectMessage(containsString("invalid_id_token"));
 		// @formatter:off
 		OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse
 				.withResponse(this.accessTokenSuccessResponse())
@@ -220,38 +212,38 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
 				.build();
 		// @formatter:on
 		given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse);
-		this.authenticationProvider
-				.authenticate(new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
+		assertThatExceptionOfType(OAuth2AuthenticationException.class)
+				.isThrownBy(() -> this.authenticationProvider.authenticate(
+						new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)))
+				.withMessageContaining("invalid_id_token");
 	}
 
 	@Test
 	public void authenticateWhenJwkSetUriNotSetThenThrowOAuth2AuthenticationException() {
-		this.exception.expect(OAuth2AuthenticationException.class);
-		this.exception.expectMessage(containsString("missing_signature_verifier"));
 		// @formatter:off
 		ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration()
 				.jwkSetUri(null)
 				.build();
 		// @formatter:on
-		this.authenticationProvider
-				.authenticate(new OAuth2LoginAuthenticationToken(clientRegistration, this.authorizationExchange));
+		assertThatExceptionOfType(OAuth2AuthenticationException.class)
+				.isThrownBy(() -> this.authenticationProvider.authenticate(
+						new OAuth2LoginAuthenticationToken(clientRegistration, this.authorizationExchange)))
+				.withMessageContaining("missing_signature_verifier");
 	}
 
 	@Test
 	public void authenticateWhenIdTokenValidationErrorThenThrowOAuth2AuthenticationException() {
-		this.exception.expect(OAuth2AuthenticationException.class);
-		this.exception.expectMessage(containsString("[invalid_id_token] ID Token Validation Error"));
 		JwtDecoder jwtDecoder = mock(JwtDecoder.class);
 		given(jwtDecoder.decode(anyString())).willThrow(new JwtException("ID Token Validation Error"));
 		this.authenticationProvider.setJwtDecoderFactory((registration) -> jwtDecoder);
-		this.authenticationProvider
-				.authenticate(new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
+		assertThatExceptionOfType(OAuth2AuthenticationException.class)
+				.isThrownBy(() -> this.authenticationProvider.authenticate(
+						new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)))
+				.withMessageContaining("[invalid_id_token] ID Token Validation Error");
 	}
 
 	@Test
 	public void authenticateWhenIdTokenInvalidNonceThenThrowOAuth2AuthenticationException() {
-		this.exception.expect(OAuth2AuthenticationException.class);
-		this.exception.expectMessage(containsString("[invalid_nonce]"));
 		Map<String, Object> claims = new HashMap<>();
 		claims.put(IdTokenClaimNames.ISS, "https://provider.com");
 		claims.put(IdTokenClaimNames.SUB, "subject1");
@@ -259,8 +251,10 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
 		claims.put(IdTokenClaimNames.AZP, "client1");
 		claims.put(IdTokenClaimNames.NONCE, "invalid-nonce-hash");
 		this.setUpIdToken(claims);
-		this.authenticationProvider
-				.authenticate(new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
+		assertThatExceptionOfType(OAuth2AuthenticationException.class)
+				.isThrownBy(() -> this.authenticationProvider.authenticate(
+						new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)))
+				.withMessageContaining("[invalid_nonce]");
 	}
 
 	@Test

+ 25 - 26
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java

@@ -29,9 +29,7 @@ import okhttp3.mockwebserver.MockWebServer;
 import okhttp3.mockwebserver.RecordedRequest;
 import org.junit.After;
 import org.junit.Before;
-import org.junit.Rule;
 import org.junit.Test;
-import org.junit.rules.ExpectedException;
 
 import org.springframework.core.convert.converter.Converter;
 import org.springframework.http.HttpHeaders;
@@ -56,8 +54,8 @@ import org.springframework.security.oauth2.core.oidc.user.OidcUser;
 import org.springframework.security.oauth2.core.oidc.user.OidcUserAuthority;
 
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
-import static org.hamcrest.CoreMatchers.containsString;
 import static org.mockito.ArgumentMatchers.same;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.mock;
@@ -80,9 +78,6 @@ public class OidcUserServiceTests {
 
 	private MockWebServer server;
 
-	@Rule
-	public ExpectedException exception = ExpectedException.none();
-
 	@Before
 	public void setup() throws Exception {
 		this.server = new MockWebServer();
@@ -133,8 +128,7 @@ public class OidcUserServiceTests {
 
 	@Test
 	public void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() {
-		this.exception.expect(IllegalArgumentException.class);
-		this.userService.loadUser(null);
+		assertThatIllegalArgumentException().isThrownBy(() -> this.userService.loadUser(null));
 	}
 
 	@Test
@@ -260,8 +254,6 @@ public class OidcUserServiceTests {
 	// gh-5447
 	@Test
 	public void loadUserWhenUserInfoSuccessResponseAndUserInfoSubjectIsNullThenThrowOAuth2AuthenticationException() {
-		this.exception.expect(OAuth2AuthenticationException.class);
-		this.exception.expectMessage(containsString("invalid_user_info_response"));
 		// @formatter:off
 		String userInfoResponse = "{\n"
 				+ "   \"email\": \"full_name@provider.com\",\n"
@@ -272,25 +264,26 @@ public class OidcUserServiceTests {
 		String userInfoUri = this.server.url("/user").toString();
 		ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri)
 				.userNameAttributeName(StandardClaimNames.EMAIL).build();
-		this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken));
+		assertThatExceptionOfType(OAuth2AuthenticationException.class)
+				.isThrownBy(() -> this.userService
+						.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)))
+				.withMessageContaining("invalid_user_info_response");
 	}
 
 	@Test
 	public void loadUserWhenUserInfoSuccessResponseAndUserInfoSubjectNotSameAsIdTokenSubjectThenThrowOAuth2AuthenticationException() {
-		this.exception.expect(OAuth2AuthenticationException.class);
-		this.exception.expectMessage(containsString("invalid_user_info_response"));
 		String userInfoResponse = "{\n" + "	\"sub\": \"other-subject\"\n" + "}\n";
 		this.server.enqueue(jsonResponse(userInfoResponse));
 		String userInfoUri = this.server.url("/user").toString();
 		ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build();
-		this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken));
+		assertThatExceptionOfType(OAuth2AuthenticationException.class)
+				.isThrownBy(() -> this.userService
+						.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)))
+				.withMessageContaining("invalid_user_info_response");
 	}
 
 	@Test
 	public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() {
-		this.exception.expect(OAuth2AuthenticationException.class);
-		this.exception.expectMessage(containsString(
-				"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"));
 		// @formatter:off
 		String userInfoResponse = "{\n"
 			+ "   \"sub\": \"subject1\",\n"
@@ -304,28 +297,34 @@ public class OidcUserServiceTests {
 		this.server.enqueue(jsonResponse(userInfoResponse));
 		String userInfoUri = this.server.url("/user").toString();
 		ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build();
-		this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken));
+		assertThatExceptionOfType(OAuth2AuthenticationException.class)
+				.isThrownBy(() -> this.userService
+						.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)))
+				.withMessageContaining(
+						"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource");
 	}
 
 	@Test
 	public void loadUserWhenServerErrorThenThrowOAuth2AuthenticationException() {
-		this.exception.expect(OAuth2AuthenticationException.class);
-		this.exception.expectMessage(containsString(
-				"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource: 500 Server Error"));
 		this.server.enqueue(new MockResponse().setResponseCode(500));
 		String userInfoUri = this.server.url("/user").toString();
 		ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build();
-		this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken));
+		assertThatExceptionOfType(OAuth2AuthenticationException.class)
+				.isThrownBy(() -> this.userService
+						.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)))
+				.withMessageContaining(
+						"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource: 500 Server Error");
 	}
 
 	@Test
 	public void loadUserWhenUserInfoUriInvalidThenThrowOAuth2AuthenticationException() {
-		this.exception.expect(OAuth2AuthenticationException.class);
-		this.exception.expectMessage(containsString(
-				"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"));
 		String userInfoUri = "https://invalid-provider.com/user";
 		ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build();
-		this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken));
+		assertThatExceptionOfType(OAuth2AuthenticationException.class)
+				.isThrownBy(() -> this.userService
+						.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)))
+				.withMessageContaining(
+						"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource");
 	}
 
 	@Test

+ 23 - 28
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/CustomUserTypesOAuth2UserServiceTests.java

@@ -26,9 +26,7 @@ import okhttp3.mockwebserver.MockResponse;
 import okhttp3.mockwebserver.MockWebServer;
 import org.junit.After;
 import org.junit.Before;
-import org.junit.Rule;
 import org.junit.Test;
-import org.junit.rules.ExpectedException;
 
 import org.springframework.http.HttpHeaders;
 import org.springframework.http.MediaType;
@@ -43,7 +41,8 @@ import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
 import org.springframework.security.oauth2.core.user.OAuth2User;
 
 import static org.assertj.core.api.Assertions.assertThat;
-import static org.hamcrest.CoreMatchers.containsString;
+import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
+import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 
 /**
  * Tests for {@link CustomUserTypesOAuth2UserService}.
@@ -61,9 +60,6 @@ public class CustomUserTypesOAuth2UserServiceTests {
 
 	private MockWebServer server;
 
-	@Rule
-	public ExpectedException exception = ExpectedException.none();
-
 	@Before
 	public void setUp() throws Exception {
 		this.server = new MockWebServer();
@@ -86,32 +82,28 @@ public class CustomUserTypesOAuth2UserServiceTests {
 
 	@Test
 	public void constructorWhenCustomUserTypesIsNullThenThrowIllegalArgumentException() {
-		this.exception.expect(IllegalArgumentException.class);
-		new CustomUserTypesOAuth2UserService(null);
+		assertThatIllegalArgumentException().isThrownBy(() -> new CustomUserTypesOAuth2UserService(null));
 	}
 
 	@Test
 	public void constructorWhenCustomUserTypesIsEmptyThenThrowIllegalArgumentException() {
-		this.exception.expect(IllegalArgumentException.class);
-		new CustomUserTypesOAuth2UserService(Collections.emptyMap());
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> new CustomUserTypesOAuth2UserService(Collections.emptyMap()));
 	}
 
 	@Test
 	public void setRequestEntityConverterWhenNullThenThrowIllegalArgumentException() {
-		this.exception.expect(IllegalArgumentException.class);
-		this.userService.setRequestEntityConverter(null);
+		assertThatIllegalArgumentException().isThrownBy(() -> this.userService.setRequestEntityConverter(null));
 	}
 
 	@Test
 	public void setRestOperationsWhenNullThenThrowIllegalArgumentException() {
-		this.exception.expect(IllegalArgumentException.class);
-		this.userService.setRestOperations(null);
+		assertThatIllegalArgumentException().isThrownBy(() -> this.userService.setRestOperations(null));
 	}
 
 	@Test
 	public void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() {
-		this.exception.expect(IllegalArgumentException.class);
-		this.userService.loadUser(null);
+		assertThatIllegalArgumentException().isThrownBy(() -> this.userService.loadUser(null));
 	}
 
 	@Test
@@ -151,9 +143,6 @@ public class CustomUserTypesOAuth2UserServiceTests {
 
 	@Test
 	public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() {
-		this.exception.expect(OAuth2AuthenticationException.class);
-		this.exception.expectMessage(containsString(
-				"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"));
 		// @formatter:off
 		String userInfoResponse = "{\n"
 			+ "   \"id\": \"12345\",\n"
@@ -166,28 +155,34 @@ public class CustomUserTypesOAuth2UserServiceTests {
 		this.server.enqueue(jsonResponse(userInfoResponse));
 		String userInfoUri = this.server.url("/user").toString();
 		ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build();
-		this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken));
+		assertThatExceptionOfType(OAuth2AuthenticationException.class)
+				.isThrownBy(
+						() -> this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)))
+				.withMessageContaining(
+						"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource");
 	}
 
 	@Test
 	public void loadUserWhenServerErrorThenThrowOAuth2AuthenticationException() {
-		this.exception.expect(OAuth2AuthenticationException.class);
-		this.exception.expectMessage(containsString(
-				"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource: 500 Server Error"));
 		this.server.enqueue(new MockResponse().setResponseCode(500));
 		String userInfoUri = this.server.url("/user").toString();
 		ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build();
-		this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken));
+		assertThatExceptionOfType(OAuth2AuthenticationException.class)
+				.isThrownBy(
+						() -> this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)))
+				.withMessageContaining(
+						"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource: 500 Server Error");
 	}
 
 	@Test
 	public void loadUserWhenUserInfoUriInvalidThenThrowOAuth2AuthenticationException() {
-		this.exception.expect(OAuth2AuthenticationException.class);
-		this.exception.expectMessage(containsString(
-				"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"));
 		String userInfoUri = "https://invalid-provider.com/user";
 		ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build();
-		this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken));
+		assertThatExceptionOfType(OAuth2AuthenticationException.class)
+				.isThrownBy(
+						() -> this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)))
+				.withMessageContaining(
+						"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource");
 	}
 
 	private ClientRegistration.Builder withRegistrationId(String registrationId) {

+ 46 - 46
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java

@@ -26,9 +26,7 @@ import okhttp3.mockwebserver.MockWebServer;
 import okhttp3.mockwebserver.RecordedRequest;
 import org.junit.After;
 import org.junit.Before;
-import org.junit.Rule;
 import org.junit.Test;
-import org.junit.rules.ExpectedException;
 
 import org.springframework.core.ParameterizedTypeReference;
 import org.springframework.core.convert.converter.Converter;
@@ -51,7 +49,8 @@ import org.springframework.security.oauth2.core.user.OAuth2UserAuthority;
 import org.springframework.web.client.RestOperations;
 
 import static org.assertj.core.api.Assertions.assertThat;
-import static org.hamcrest.CoreMatchers.containsString;
+import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
+import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.nullable;
 import static org.mockito.BDDMockito.given;
@@ -73,9 +72,6 @@ public class DefaultOAuth2UserServiceTests {
 
 	private MockWebServer server;
 
-	@Rule
-	public ExpectedException exception = ExpectedException.none();
-
 	@Before
 	public void setup() throws Exception {
 		this.server = new MockWebServer();
@@ -95,40 +91,39 @@ public class DefaultOAuth2UserServiceTests {
 
 	@Test
 	public void setRequestEntityConverterWhenNullThenThrowIllegalArgumentException() {
-		this.exception.expect(IllegalArgumentException.class);
-		this.userService.setRequestEntityConverter(null);
+		assertThatIllegalArgumentException().isThrownBy(() -> this.userService.setRequestEntityConverter(null));
 	}
 
 	@Test
 	public void setRestOperationsWhenNullThenThrowIllegalArgumentException() {
-		this.exception.expect(IllegalArgumentException.class);
-		this.userService.setRestOperations(null);
+		assertThatIllegalArgumentException().isThrownBy(() -> this.userService.setRestOperations(null));
 	}
 
 	@Test
 	public void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() {
-		this.exception.expect(IllegalArgumentException.class);
-		this.userService.loadUser(null);
+		assertThatIllegalArgumentException().isThrownBy(() -> this.userService.loadUser(null));
 	}
 
 	@Test
 	public void loadUserWhenUserInfoUriIsNullThenThrowOAuth2AuthenticationException() {
-		this.exception.expect(OAuth2AuthenticationException.class);
-		this.exception.expectMessage(containsString("missing_user_info_uri"));
 		ClientRegistration clientRegistration = this.clientRegistrationBuilder.build();
-		this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken));
+		assertThatExceptionOfType(OAuth2AuthenticationException.class)
+				.isThrownBy(
+						() -> this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)))
+				.withMessageContaining("missing_user_info_uri");
 	}
 
 	@Test
 	public void loadUserWhenUserNameAttributeNameIsNullThenThrowOAuth2AuthenticationException() {
-		this.exception.expect(OAuth2AuthenticationException.class);
-		this.exception.expectMessage(containsString("missing_user_name_attribute"));
 		// @formatter:off
 		ClientRegistration clientRegistration = this.clientRegistrationBuilder
 				.userInfoUri("https://provider.com/user")
 				.build();
 		// @formatter:on
-		this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken));
+		assertThatExceptionOfType(OAuth2AuthenticationException.class)
+				.isThrownBy(
+						() -> this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)))
+				.withMessageContaining("missing_user_name_attribute");
 	}
 
 	@Test
@@ -165,9 +160,6 @@ public class DefaultOAuth2UserServiceTests {
 
 	@Test
 	public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() {
-		this.exception.expect(OAuth2AuthenticationException.class);
-		this.exception.expectMessage(containsString(
-				"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"));
 		// @formatter:off
 		String userInfoResponse = "{\n"
 			+ "	\"user-name\": \"user1\",\n"
@@ -182,16 +174,15 @@ public class DefaultOAuth2UserServiceTests {
 		String userInfoUri = this.server.url("/user").toString();
 		ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri)
 				.userInfoAuthenticationMethod(AuthenticationMethod.HEADER).userNameAttributeName("user-name").build();
-		this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken));
+		assertThatExceptionOfType(OAuth2AuthenticationException.class)
+				.isThrownBy(
+						() -> this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)))
+				.withMessageContaining(
+						"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource");
 	}
 
 	@Test
 	public void loadUserWhenUserInfoErrorResponseWwwAuthenticateHeaderThenThrowOAuth2AuthenticationException() {
-		this.exception.expect(OAuth2AuthenticationException.class);
-		this.exception.expectMessage(containsString(
-				"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"));
-		this.exception.expectMessage(
-				containsString("Error Code: insufficient_scope, Error Description: The access token expired"));
 		String wwwAuthenticateHeader = "Bearer realm=\"auth-realm\" error=\"insufficient_scope\" error_description=\"The access token expired\"";
 		MockResponse response = new MockResponse();
 		response.setHeader(HttpHeaders.WWW_AUTHENTICATE, wwwAuthenticateHeader);
@@ -200,15 +191,16 @@ public class DefaultOAuth2UserServiceTests {
 		String userInfoUri = this.server.url("/user").toString();
 		ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri)
 				.userInfoAuthenticationMethod(AuthenticationMethod.HEADER).userNameAttributeName("user-name").build();
-		this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken));
+		assertThatExceptionOfType(OAuth2AuthenticationException.class)
+				.isThrownBy(
+						() -> this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)))
+				.withMessageContaining(
+						"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource")
+				.withMessageContaining("Error Code: insufficient_scope, Error Description: The access token expired");
 	}
 
 	@Test
 	public void loadUserWhenUserInfoErrorResponseThenThrowOAuth2AuthenticationException() {
-		this.exception.expect(OAuth2AuthenticationException.class);
-		this.exception.expectMessage(containsString(
-				"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"));
-		this.exception.expectMessage(containsString("Error Code: invalid_token"));
 		// @formatter:off
 		String userInfoErrorResponse = "{\n"
 				+ "   \"error\": \"invalid_token\"\n"
@@ -218,30 +210,37 @@ public class DefaultOAuth2UserServiceTests {
 		String userInfoUri = this.server.url("/user").toString();
 		ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri)
 				.userInfoAuthenticationMethod(AuthenticationMethod.HEADER).userNameAttributeName("user-name").build();
-		this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken));
+		assertThatExceptionOfType(OAuth2AuthenticationException.class)
+				.isThrownBy(
+						() -> this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)))
+				.withMessageContaining(
+						"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource")
+				.withMessageContaining("Error Code: invalid_token");
 	}
 
 	@Test
 	public void loadUserWhenServerErrorThenThrowOAuth2AuthenticationException() {
-		this.exception.expect(OAuth2AuthenticationException.class);
-		this.exception.expectMessage(containsString(
-				"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource: 500 Server Error"));
 		this.server.enqueue(new MockResponse().setResponseCode(500));
 		String userInfoUri = this.server.url("/user").toString();
 		ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri)
 				.userInfoAuthenticationMethod(AuthenticationMethod.HEADER).userNameAttributeName("user-name").build();
-		this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken));
+		assertThatExceptionOfType(OAuth2AuthenticationException.class)
+				.isThrownBy(
+						() -> this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)))
+				.withMessageContaining(
+						"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource: 500 Server Error");
 	}
 
 	@Test
 	public void loadUserWhenUserInfoUriInvalidThenThrowOAuth2AuthenticationException() {
-		this.exception.expect(OAuth2AuthenticationException.class);
-		this.exception.expectMessage(containsString(
-				"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"));
 		String userInfoUri = "https://invalid-provider.com/user";
 		ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri)
 				.userInfoAuthenticationMethod(AuthenticationMethod.HEADER).userNameAttributeName("user-name").build();
-		this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken));
+		assertThatExceptionOfType(OAuth2AuthenticationException.class)
+				.isThrownBy(
+						() -> this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)))
+				.withMessageContaining(
+						"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource");
 	}
 
 	// gh-5294
@@ -348,17 +347,18 @@ public class DefaultOAuth2UserServiceTests {
 	@Test
 	public void loadUserWhenUserInfoSuccessResponseInvalidContentTypeThenThrowOAuth2AuthenticationException() {
 		String userInfoUri = this.server.url("/user").toString();
-		this.exception.expect(OAuth2AuthenticationException.class);
-		this.exception.expectMessage(containsString(
-				"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource "
-						+ "from '" + userInfoUri + "': response contains invalid content type 'text/plain'."));
 		MockResponse response = new MockResponse();
 		response.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.TEXT_PLAIN_VALUE);
 		response.setBody("invalid content type");
 		this.server.enqueue(response);
 		ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri)
 				.userInfoAuthenticationMethod(AuthenticationMethod.HEADER).userNameAttributeName("user-name").build();
-		this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken));
+		assertThatExceptionOfType(OAuth2AuthenticationException.class)
+				.isThrownBy(
+						() -> this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)))
+				.withMessageContaining(
+						"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource "
+								+ "from '" + userInfoUri + "': response contains invalid content type 'text/plain'.");
 	}
 
 	private DefaultOAuth2UserService withMockResponse(Map<String, Object> response) {

+ 27 - 37
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/Saml2X509CredentialTests.java

@@ -23,17 +23,15 @@ import java.security.cert.CertificateFactory;
 import java.security.cert.X509Certificate;
 
 import org.junit.Before;
-import org.junit.Rule;
 import org.junit.Test;
-import org.junit.rules.ExpectedException;
 
 import org.springframework.security.converter.RsaKeyConverters;
 import org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType;
 
-public class Saml2X509CredentialTests {
+import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
+import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
 
-	@Rule
-	public ExpectedException exception = ExpectedException.none();
+public class Saml2X509CredentialTests {
 
 	private PrivateKey key;
 
@@ -99,98 +97,90 @@ public class Saml2X509CredentialTests {
 
 	@Test
 	public void constructorWhenRelyingPartyWithoutCredentialsThenItFails() {
-		this.exception.expect(IllegalArgumentException.class);
-		new Saml2X509Credential(null, (X509Certificate) null, Saml2X509CredentialType.SIGNING);
+		assertThatIllegalArgumentException().isThrownBy(
+				() -> new Saml2X509Credential(null, (X509Certificate) null, Saml2X509CredentialType.SIGNING));
 	}
 
 	@Test
 	public void constructorWhenRelyingPartyWithoutPrivateKeyThenItFails() {
-		this.exception.expect(IllegalArgumentException.class);
-		new Saml2X509Credential(null, this.certificate, Saml2X509CredentialType.SIGNING);
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> new Saml2X509Credential(null, this.certificate, Saml2X509CredentialType.SIGNING));
 	}
 
 	@Test
 	public void constructorWhenRelyingPartyWithoutCertificateThenItFails() {
-		this.exception.expect(IllegalArgumentException.class);
-		new Saml2X509Credential(this.key, null, Saml2X509CredentialType.SIGNING);
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> new Saml2X509Credential(this.key, null, Saml2X509CredentialType.SIGNING));
 	}
 
 	@Test
 	public void constructorWhenAssertingPartyWithoutCertificateThenItFails() {
-		this.exception.expect(IllegalArgumentException.class);
-		new Saml2X509Credential(null, Saml2X509CredentialType.SIGNING);
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> new Saml2X509Credential(null, Saml2X509CredentialType.SIGNING));
 	}
 
 	@Test
 	public void constructorWhenRelyingPartyWithEncryptionUsageThenItFails() {
-		this.exception.expect(IllegalStateException.class);
-		new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.ENCRYPTION);
+		assertThatIllegalStateException().isThrownBy(
+				() -> new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.ENCRYPTION));
 	}
 
 	@Test
 	public void constructorWhenRelyingPartyWithVerificationUsageThenItFails() {
-		this.exception.expect(IllegalStateException.class);
-		new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.VERIFICATION);
+		assertThatIllegalStateException().isThrownBy(
+				() -> new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.VERIFICATION));
 	}
 
 	@Test
 	public void constructorWhenAssertingPartyWithSigningUsageThenItFails() {
-		this.exception.expect(IllegalStateException.class);
-		new Saml2X509Credential(this.certificate, Saml2X509CredentialType.SIGNING);
+		assertThatIllegalStateException()
+				.isThrownBy(() -> new Saml2X509Credential(this.certificate, Saml2X509CredentialType.SIGNING));
 	}
 
 	@Test
 	public void constructorWhenAssertingPartyWithDecryptionUsageThenItFails() {
-		this.exception.expect(IllegalStateException.class);
-		new Saml2X509Credential(this.certificate, Saml2X509CredentialType.DECRYPTION);
+		assertThatIllegalStateException()
+				.isThrownBy(() -> new Saml2X509Credential(this.certificate, Saml2X509CredentialType.DECRYPTION));
 	}
 
 	@Test
 	public void factoryWhenRelyingPartyForSigningWithoutCredentialsThenItFails() {
-		this.exception.expect(IllegalArgumentException.class);
-		Saml2X509Credential.signing(null, null);
+		assertThatIllegalArgumentException().isThrownBy(() -> Saml2X509Credential.signing(null, null));
 	}
 
 	@Test
 	public void factoryWhenRelyingPartyForSigningWithoutPrivateKeyThenItFails() {
-		this.exception.expect(IllegalArgumentException.class);
-		Saml2X509Credential.signing(null, this.certificate);
+		assertThatIllegalArgumentException().isThrownBy(() -> Saml2X509Credential.signing(null, this.certificate));
 	}
 
 	@Test
 	public void factoryWhenRelyingPartyForSigningWithoutCertificateThenItFails() {
-		this.exception.expect(IllegalArgumentException.class);
-		Saml2X509Credential.signing(this.key, null);
+		assertThatIllegalArgumentException().isThrownBy(() -> Saml2X509Credential.signing(this.key, null));
 	}
 
 	@Test
 	public void factoryWhenRelyingPartyForDecryptionWithoutCredentialsThenItFails() {
-		this.exception.expect(IllegalArgumentException.class);
-		Saml2X509Credential.decryption(null, null);
+		assertThatIllegalArgumentException().isThrownBy(() -> Saml2X509Credential.decryption(null, null));
 	}
 
 	@Test
 	public void factoryWhenRelyingPartyForDecryptionWithoutPrivateKeyThenItFails() {
-		this.exception.expect(IllegalArgumentException.class);
-		Saml2X509Credential.decryption(null, this.certificate);
+		assertThatIllegalArgumentException().isThrownBy(() -> Saml2X509Credential.decryption(null, this.certificate));
 	}
 
 	@Test
 	public void factoryWhenRelyingPartyForDecryptionWithoutCertificateThenItFails() {
-		this.exception.expect(IllegalArgumentException.class);
-		Saml2X509Credential.decryption(this.key, null);
+		assertThatIllegalArgumentException().isThrownBy(() -> Saml2X509Credential.decryption(this.key, null));
 	}
 
 	@Test
 	public void factoryWhenAssertingPartyForVerificationWithoutCertificateThenItFails() {
-		this.exception.expect(IllegalArgumentException.class);
-		Saml2X509Credential.verification(null);
+		assertThatIllegalArgumentException().isThrownBy(() -> Saml2X509Credential.verification(null));
 	}
 
 	@Test
 	public void factoryWhenAssertingPartyForEncryptionWithoutCertificateThenItFails() {
-		this.exception.expect(IllegalArgumentException.class);
-		Saml2X509Credential.encryption(null);
+		assertThatIllegalArgumentException().isThrownBy(() -> Saml2X509Credential.encryption(null));
 	}
 
 }

+ 19 - 21
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/credentials/Saml2X509CredentialTests.java

@@ -23,17 +23,15 @@ import java.security.cert.CertificateFactory;
 import java.security.cert.X509Certificate;
 
 import org.junit.Before;
-import org.junit.Rule;
 import org.junit.Test;
-import org.junit.rules.ExpectedException;
 
 import org.springframework.security.converter.RsaKeyConverters;
 import org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType;
 
-public class Saml2X509CredentialTests {
+import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
+import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
 
-	@Rule
-	public ExpectedException exception = ExpectedException.none();
+public class Saml2X509CredentialTests {
 
 	private Saml2X509Credential credential;
 
@@ -97,50 +95,50 @@ public class Saml2X509CredentialTests {
 
 	@Test
 	public void constructorWhenRelyingPartyWithoutCredentialsThenItFails() {
-		this.exception.expect(IllegalArgumentException.class);
-		new Saml2X509Credential(null, (X509Certificate) null, Saml2X509CredentialType.SIGNING);
+		assertThatIllegalArgumentException().isThrownBy(
+				() -> new Saml2X509Credential(null, (X509Certificate) null, Saml2X509CredentialType.SIGNING));
 	}
 
 	@Test
 	public void constructorWhenRelyingPartyWithoutPrivateKeyThenItFails() {
-		this.exception.expect(IllegalArgumentException.class);
-		new Saml2X509Credential(null, this.certificate, Saml2X509CredentialType.SIGNING);
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> new Saml2X509Credential(null, this.certificate, Saml2X509CredentialType.SIGNING));
 	}
 
 	@Test
 	public void constructorWhenRelyingPartyWithoutCertificateThenItFails() {
-		this.exception.expect(IllegalArgumentException.class);
-		new Saml2X509Credential(this.key, null, Saml2X509CredentialType.SIGNING);
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> new Saml2X509Credential(this.key, null, Saml2X509CredentialType.SIGNING));
 	}
 
 	@Test
 	public void constructorWhenAssertingPartyWithoutCertificateThenItFails() {
-		this.exception.expect(IllegalArgumentException.class);
-		new Saml2X509Credential(null, Saml2X509CredentialType.SIGNING);
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> new Saml2X509Credential(null, Saml2X509CredentialType.SIGNING));
 	}
 
 	@Test
 	public void constructorWhenRelyingPartyWithEncryptionUsageThenItFails() {
-		this.exception.expect(IllegalStateException.class);
-		new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.ENCRYPTION);
+		assertThatIllegalStateException().isThrownBy(
+				() -> new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.ENCRYPTION));
 	}
 
 	@Test
 	public void constructorWhenRelyingPartyWithVerificationUsageThenItFails() {
-		this.exception.expect(IllegalStateException.class);
-		new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.VERIFICATION);
+		assertThatIllegalStateException().isThrownBy(
+				() -> new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.VERIFICATION));
 	}
 
 	@Test
 	public void constructorWhenAssertingPartyWithSigningUsageThenItFails() {
-		this.exception.expect(IllegalStateException.class);
-		new Saml2X509Credential(this.certificate, Saml2X509CredentialType.SIGNING);
+		assertThatIllegalStateException()
+				.isThrownBy(() -> new Saml2X509Credential(this.certificate, Saml2X509CredentialType.SIGNING));
 	}
 
 	@Test
 	public void constructorWhenAssertingPartyWithDecryptionUsageThenItFails() {
-		this.exception.expect(IllegalStateException.class);
-		new Saml2X509Credential(this.certificate, Saml2X509CredentialType.DECRYPTION);
+		assertThatIllegalStateException()
+				.isThrownBy(() -> new Saml2X509Credential(this.certificate, Saml2X509CredentialType.DECRYPTION));
 	}
 
 }

+ 43 - 60
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java

@@ -26,18 +26,14 @@ import java.util.HashMap;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.function.Consumer;
 
 import javax.xml.namespace.QName;
 
 import net.shibboleth.utilities.java.support.xml.SerializeSupport;
-import org.hamcrest.BaseMatcher;
-import org.hamcrest.Description;
-import org.hamcrest.Matcher;
 import org.joda.time.DateTime;
 import org.joda.time.Duration;
-import org.junit.Rule;
 import org.junit.Test;
-import org.junit.rules.ExpectedException;
 import org.opensaml.core.xml.XMLObject;
 import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
 import org.opensaml.core.xml.io.Marshaller;
@@ -93,9 +89,6 @@ public class OpenSamlAuthenticationProviderTests {
 	private Saml2Authentication authentication = new Saml2Authentication(this.principal, "response",
 			Collections.emptyList());
 
-	@Rule
-	public ExpectedException exception = ExpectedException.none();
-
 	@Test
 	public void supportsWhenSaml2AuthenticationTokenThenReturnTrue() {
 		assertThat(this.provider.supports(Saml2AuthenticationToken.class))
@@ -113,53 +106,56 @@ public class OpenSamlAuthenticationProviderTests {
 
 	@Test
 	public void authenticateWhenUnknownDataClassThenThrowAuthenticationException() {
-		this.exception.expect(authenticationMatcher(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA));
 		Assertion assertion = (Assertion) XMLObjectProviderRegistrySupport.getBuilderFactory()
 				.getBuilder(Assertion.DEFAULT_ELEMENT_NAME).buildObject(Assertion.DEFAULT_ELEMENT_NAME);
-		this.provider
-				.authenticate(token(serialize(assertion), TestSaml2X509Credentials.relyingPartyVerifyingCredential()));
+		assertThatExceptionOfType(Saml2AuthenticationException.class)
+				.isThrownBy(() -> this.provider.authenticate(
+						token(serialize(assertion), TestSaml2X509Credentials.relyingPartyVerifyingCredential())))
+				.satisfies(errorOf(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA));
 	}
 
 	@Test
 	public void authenticateWhenXmlErrorThenThrowAuthenticationException() {
-		this.exception.expect(authenticationMatcher(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA));
 		Saml2AuthenticationToken token = token("invalid xml",
 				TestSaml2X509Credentials.relyingPartyVerifyingCredential());
-		this.provider.authenticate(token);
+		assertThatExceptionOfType(Saml2AuthenticationException.class)
+				.isThrownBy(() -> this.provider.authenticate(token))
+				.satisfies(errorOf(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA));
 	}
 
 	@Test
 	public void authenticateWhenInvalidDestinationThenThrowAuthenticationException() {
-		this.exception.expect(authenticationMatcher(Saml2ErrorCodes.INVALID_DESTINATION));
 		Response response = TestOpenSamlObjects.response(DESTINATION + "invalid", ASSERTING_PARTY_ENTITY_ID);
 		response.getAssertions().add(TestOpenSamlObjects.assertion());
 		TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(),
 				RELYING_PARTY_ENTITY_ID);
 		Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
-		this.provider.authenticate(token);
+		assertThatExceptionOfType(Saml2AuthenticationException.class)
+				.isThrownBy(() -> this.provider.authenticate(token))
+				.satisfies(errorOf(Saml2ErrorCodes.INVALID_DESTINATION));
 	}
 
 	@Test
 	public void authenticateWhenNoAssertionsPresentThenThrowAuthenticationException() {
-		this.exception.expect(
-				authenticationMatcher(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA, "No assertions found in response."));
 		Saml2AuthenticationToken token = token(TestOpenSamlObjects.response(),
 				TestSaml2X509Credentials.assertingPartySigningCredential());
-		this.provider.authenticate(token);
+		assertThatExceptionOfType(Saml2AuthenticationException.class)
+				.isThrownBy(() -> this.provider.authenticate(token))
+				.satisfies(errorOf(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA, "No assertions found in response."));
 	}
 
 	@Test
 	public void authenticateWhenInvalidSignatureOnAssertionThenThrowAuthenticationException() {
-		this.exception.expect(authenticationMatcher(Saml2ErrorCodes.INVALID_SIGNATURE));
 		Response response = TestOpenSamlObjects.response();
 		response.getAssertions().add(TestOpenSamlObjects.assertion());
 		Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
-		this.provider.authenticate(token);
+		assertThatExceptionOfType(Saml2AuthenticationException.class)
+				.isThrownBy(() -> this.provider.authenticate(token))
+				.satisfies(errorOf(Saml2ErrorCodes.INVALID_SIGNATURE));
 	}
 
 	@Test
 	public void authenticateWhenOpenSAMLValidationErrorThenThrowAuthenticationException() throws Exception {
-		this.exception.expect(authenticationMatcher(Saml2ErrorCodes.INVALID_ASSERTION));
 		Response response = TestOpenSamlObjects.response();
 		Assertion assertion = TestOpenSamlObjects.assertion();
 		assertion.getSubject().getSubjectConfirmations().get(0).getSubjectConfirmationData()
@@ -168,12 +164,13 @@ public class OpenSamlAuthenticationProviderTests {
 				RELYING_PARTY_ENTITY_ID);
 		response.getAssertions().add(assertion);
 		Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
-		this.provider.authenticate(token);
+		assertThatExceptionOfType(Saml2AuthenticationException.class)
+				.isThrownBy(() -> this.provider.authenticate(token))
+				.satisfies(errorOf(Saml2ErrorCodes.INVALID_ASSERTION));
 	}
 
 	@Test
 	public void authenticateWhenMissingSubjectThenThrowAuthenticationException() {
-		this.exception.expect(authenticationMatcher(Saml2ErrorCodes.SUBJECT_NOT_FOUND));
 		Response response = TestOpenSamlObjects.response();
 		Assertion assertion = TestOpenSamlObjects.assertion();
 		assertion.setSubject(null);
@@ -181,12 +178,13 @@ public class OpenSamlAuthenticationProviderTests {
 				RELYING_PARTY_ENTITY_ID);
 		response.getAssertions().add(assertion);
 		Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
-		this.provider.authenticate(token);
+		assertThatExceptionOfType(Saml2AuthenticationException.class)
+				.isThrownBy(() -> this.provider.authenticate(token))
+				.satisfies(errorOf(Saml2ErrorCodes.SUBJECT_NOT_FOUND));
 	}
 
 	@Test
 	public void authenticateWhenUsernameMissingThenThrowAuthenticationException() throws Exception {
-		this.exception.expect(authenticationMatcher(Saml2ErrorCodes.SUBJECT_NOT_FOUND));
 		Response response = TestOpenSamlObjects.response();
 		Assertion assertion = TestOpenSamlObjects.assertion();
 		assertion.getSubject().getNameID().setValue(null);
@@ -194,7 +192,9 @@ public class OpenSamlAuthenticationProviderTests {
 				RELYING_PARTY_ENTITY_ID);
 		response.getAssertions().add(assertion);
 		Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
-		this.provider.authenticate(token);
+		assertThatExceptionOfType(Saml2AuthenticationException.class)
+				.isThrownBy(() -> this.provider.authenticate(token))
+				.satisfies(errorOf(Saml2ErrorCodes.SUBJECT_NOT_FOUND));
 	}
 
 	@Test
@@ -236,13 +236,14 @@ public class OpenSamlAuthenticationProviderTests {
 
 	@Test
 	public void authenticateWhenEncryptedAssertionWithoutSignatureThenItFails() throws Exception {
-		this.exception.expect(authenticationMatcher(Saml2ErrorCodes.INVALID_SIGNATURE));
 		Response response = TestOpenSamlObjects.response();
 		EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(TestOpenSamlObjects.assertion(),
 				TestSaml2X509Credentials.assertingPartyEncryptingCredential());
 		response.getEncryptedAssertions().add(encryptedAssertion);
 		Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyDecryptingCredential());
-		this.provider.authenticate(token);
+		assertThatExceptionOfType(Saml2AuthenticationException.class)
+				.isThrownBy(() -> this.provider.authenticate(token))
+				.satisfies(errorOf(Saml2ErrorCodes.INVALID_SIGNATURE));
 	}
 
 	@Test
@@ -290,28 +291,28 @@ public class OpenSamlAuthenticationProviderTests {
 
 	@Test
 	public void authenticateWhenDecryptionKeysAreMissingThenThrowAuthenticationException() throws Exception {
-		this.exception
-				.expect(authenticationMatcher(Saml2ErrorCodes.DECRYPTION_ERROR, "Failed to decrypt EncryptedData"));
 		Response response = TestOpenSamlObjects.response();
 		EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(TestOpenSamlObjects.assertion(),
 				TestSaml2X509Credentials.assertingPartyEncryptingCredential());
 		response.getEncryptedAssertions().add(encryptedAssertion);
 		Saml2AuthenticationToken token = token(serialize(response),
 				TestSaml2X509Credentials.relyingPartyVerifyingCredential());
-		this.provider.authenticate(token);
+		assertThatExceptionOfType(Saml2AuthenticationException.class)
+				.isThrownBy(() -> this.provider.authenticate(token))
+				.satisfies(errorOf(Saml2ErrorCodes.DECRYPTION_ERROR, "Failed to decrypt EncryptedData"));
 	}
 
 	@Test
 	public void authenticateWhenDecryptionKeysAreWrongThenThrowAuthenticationException() throws Exception {
-		this.exception
-				.expect(authenticationMatcher(Saml2ErrorCodes.DECRYPTION_ERROR, "Failed to decrypt EncryptedData"));
 		Response response = TestOpenSamlObjects.response();
 		EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(TestOpenSamlObjects.assertion(),
 				TestSaml2X509Credentials.assertingPartyEncryptingCredential());
 		response.getEncryptedAssertions().add(encryptedAssertion);
 		Saml2AuthenticationToken token = token(serialize(response),
 				TestSaml2X509Credentials.assertingPartyPrivateCredential());
-		this.provider.authenticate(token);
+		assertThatExceptionOfType(Saml2AuthenticationException.class)
+				.isThrownBy(() -> this.provider.authenticate(token))
+				.satisfies(errorOf(Saml2ErrorCodes.DECRYPTION_ERROR, "Failed to decrypt EncryptedData"));
 	}
 
 	@Test
@@ -487,33 +488,15 @@ public class OpenSamlAuthenticationProviderTests {
 		}
 	}
 
-	private Matcher<Saml2AuthenticationException> authenticationMatcher(String code) {
-		return authenticationMatcher(code, null);
-	}
-
-	private Matcher<Saml2AuthenticationException> authenticationMatcher(String code, String description) {
-		return new BaseMatcher<Saml2AuthenticationException>() {
-			@Override
-			public boolean matches(Object item) {
-				if (!(item instanceof Saml2AuthenticationException)) {
-					return false;
-				}
-				Saml2AuthenticationException ex = (Saml2AuthenticationException) item;
-				if (!code.equals(ex.getError().getErrorCode())) {
-					return false;
-				}
-				if (StringUtils.hasText(description)) {
-					if (!description.equals(ex.getError().getDescription())) {
-						return false;
-					}
-				}
-				return true;
-			}
+	private Consumer<Saml2AuthenticationException> errorOf(String errorCode) {
+		return errorOf(errorCode, null);
+	}
 
-			@Override
-			public void describeTo(Description desc) {
-				String excepting = "Saml2AuthenticationException[code=" + code + "; description=" + description + "]";
-				desc.appendText(excepting);
+	private Consumer<Saml2AuthenticationException> errorOf(String errorCode, String description) {
+		return (ex) -> {
+			assertThat(ex.getError().getErrorCode()).isEqualTo(errorCode);
+			if (StringUtils.hasText(description)) {
+				assertThat(ex.getError().getDescription()).isEqualTo(description);
 			}
 		};
 	}

+ 2 - 9
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java

@@ -21,9 +21,7 @@ import java.nio.charset.StandardCharsets;
 
 import org.junit.Assert;
 import org.junit.Before;
-import org.junit.Rule;
 import org.junit.Test;
-import org.junit.rules.ExpectedException;
 import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
 import org.opensaml.saml.common.xml.SAMLConstants;
 import org.opensaml.saml.saml2.core.AuthnRequest;
@@ -39,7 +37,6 @@ import org.springframework.security.saml2.provider.service.registration.Saml2Mes
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
-import static org.hamcrest.CoreMatchers.containsString;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
@@ -61,9 +58,6 @@ public class OpenSamlAuthenticationRequestFactoryTests {
 
 	private AuthnRequestUnmarshaller unmarshaller;
 
-	@Rule
-	public ExpectedException exception = ExpectedException.none();
-
 	@Before
 	public void setUp() {
 		this.relyingPartyRegistrationBuilder = RelyingPartyRegistration.withRegistrationId("id")
@@ -160,9 +154,8 @@ public class OpenSamlAuthenticationRequestFactoryTests {
 
 	@Test
 	public void createAuthenticationRequestWhenSetUnsupportredUriThenThrowsIllegalArgumentException() {
-		this.exception.expect(IllegalArgumentException.class);
-		this.exception.expectMessage(containsString("my-invalid-binding"));
-		this.factory.setProtocolBinding("my-invalid-binding");
+		assertThatIllegalArgumentException().isThrownBy(() -> this.factory.setProtocolBinding("my-invalid-binding"))
+				.withMessageContaining("my-invalid-binding");
 	}
 
 	@Test

+ 3 - 8
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilterTests.java

@@ -20,9 +20,7 @@ import javax.servlet.http.HttpServletResponse;
 
 import org.junit.Assert;
 import org.junit.Before;
-import org.junit.Rule;
 import org.junit.Test;
-import org.junit.rules.ExpectedException;
 
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
@@ -43,9 +41,6 @@ public class Saml2WebSsoAuthenticationFilterTests {
 
 	private HttpServletResponse response = new MockHttpServletResponse();
 
-	@Rule
-	public ExpectedException exception = ExpectedException.none();
-
 	@Before
 	public void setup() {
 		this.filter = new Saml2WebSsoAuthenticationFilter(this.repository);
@@ -55,9 +50,9 @@ public class Saml2WebSsoAuthenticationFilterTests {
 
 	@Test
 	public void constructingFilterWithMissingRegistrationIdVariableThenThrowsException() {
-		this.exception.expect(IllegalArgumentException.class);
-		this.exception.expectMessage("filterProcessesUrl must contain a {registrationId} match variable");
-		this.filter = new Saml2WebSsoAuthenticationFilter(this.repository, "/url/missing/variable");
+		assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(
+				() -> this.filter = new Saml2WebSsoAuthenticationFilter(this.repository, "/url/missing/variable"))
+				.withMessage("filterProcessesUrl must contain a {registrationId} match variable");
 	}
 
 	@Test

+ 10 - 14
web/src/test/java/org/springframework/security/web/authentication/DelegatingAuthenticationFailureHandlerTests.java

@@ -22,9 +22,7 @@ import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
 
 import org.junit.Before;
-import org.junit.Rule;
 import org.junit.Test;
-import org.junit.rules.ExpectedException;
 import org.junit.runner.RunWith;
 import org.mockito.Mock;
 import org.mockito.junit.MockitoJUnitRunner;
@@ -35,6 +33,7 @@ import org.springframework.security.authentication.BadCredentialsException;
 import org.springframework.security.authentication.CredentialsExpiredException;
 import org.springframework.security.core.AuthenticationException;
 
+import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyZeroInteractions;
 
@@ -48,9 +47,6 @@ import static org.mockito.Mockito.verifyZeroInteractions;
 @RunWith(MockitoJUnitRunner.class)
 public class DelegatingAuthenticationFailureHandlerTests {
 
-	@Rule
-	public ExpectedException thrown = ExpectedException.none();
-
 	@Mock
 	private AuthenticationFailureHandler handler1;
 
@@ -110,24 +106,24 @@ public class DelegatingAuthenticationFailureHandlerTests {
 
 	@Test
 	public void handlersIsNull() {
-		this.thrown.expect(IllegalArgumentException.class);
-		this.thrown.expectMessage("handlers cannot be null or empty");
-		new DelegatingAuthenticationFailureHandler(null, this.defaultHandler);
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> new DelegatingAuthenticationFailureHandler(null, this.defaultHandler))
+				.withMessage("handlers cannot be null or empty");
 	}
 
 	@Test
 	public void handlersIsEmpty() {
-		this.thrown.expect(IllegalArgumentException.class);
-		this.thrown.expectMessage("handlers cannot be null or empty");
-		new DelegatingAuthenticationFailureHandler(this.handlers, this.defaultHandler);
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> new DelegatingAuthenticationFailureHandler(this.handlers, this.defaultHandler))
+				.withMessage("handlers cannot be null or empty");
 	}
 
 	@Test
 	public void defaultHandlerIsNull() {
-		this.thrown.expect(IllegalArgumentException.class);
-		this.thrown.expectMessage("defaultHandler cannot be null");
 		this.handlers.put(BadCredentialsException.class, this.handler1);
-		new DelegatingAuthenticationFailureHandler(this.handlers, null);
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> new DelegatingAuthenticationFailureHandler(this.handlers, null))
+				.withMessage("defaultHandler cannot be null");
 	}
 
 }

+ 2 - 8
web/src/test/java/org/springframework/security/web/authentication/logout/CompositeLogoutHandlerTests.java

@@ -22,9 +22,7 @@ import java.util.List;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
 
-import org.junit.Rule;
 import org.junit.Test;
-import org.junit.rules.ExpectedException;
 import org.mockito.InOrder;
 
 import org.springframework.security.core.Authentication;
@@ -45,14 +43,10 @@ import static org.mockito.Mockito.verify;
  */
 public class CompositeLogoutHandlerTests {
 
-	@Rule
-	public ExpectedException exception = ExpectedException.none();
-
 	@Test
 	public void buildEmptyCompositeLogoutHandlerThrowsException() {
-		this.exception.expect(IllegalArgumentException.class);
-		this.exception.expectMessage("LogoutHandlers are required");
-		new CompositeLogoutHandler();
+		assertThatIllegalArgumentException().isThrownBy(() -> new CompositeLogoutHandler())
+				.withMessage("LogoutHandlers are required");
 	}
 
 	@Test

+ 5 - 11
web/src/test/java/org/springframework/security/web/authentication/logout/ForwardLogoutSuccessHandlerTests.java

@@ -16,15 +16,14 @@
 
 package org.springframework.security.web.authentication.logout;
 
-import org.junit.Rule;
 import org.junit.Test;
-import org.junit.rules.ExpectedException;
 
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.security.core.Authentication;
 
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.mockito.Mockito.mock;
 
 /**
@@ -34,23 +33,18 @@ import static org.mockito.Mockito.mock;
  */
 public class ForwardLogoutSuccessHandlerTests {
 
-	@Rule
-	public ExpectedException thrown = ExpectedException.none();
-
 	@Test
 	public void invalidTargetUrl() {
 		String targetUrl = "not.valid";
-		this.thrown.expect(IllegalArgumentException.class);
-		this.thrown.expectMessage("'" + targetUrl + "' is not a valid target URL");
-		new ForwardLogoutSuccessHandler(targetUrl);
+		assertThatIllegalArgumentException().isThrownBy(() -> new ForwardLogoutSuccessHandler(targetUrl))
+				.withMessage("'" + targetUrl + "' is not a valid target URL");
 	}
 
 	@Test
 	public void emptyTargetUrl() {
 		String targetUrl = " ";
-		this.thrown.expect(IllegalArgumentException.class);
-		this.thrown.expectMessage("'" + targetUrl + "' is not a valid target URL");
-		new ForwardLogoutSuccessHandler(targetUrl);
+		assertThatIllegalArgumentException().isThrownBy(() -> new ForwardLogoutSuccessHandler(targetUrl))
+				.withMessage("'" + targetUrl + "' is not a valid target URL");
 	}
 
 	@Test

+ 3 - 8
web/src/test/java/org/springframework/security/web/authentication/logout/HeaderWriterLogoutHandlerTests.java

@@ -17,15 +17,14 @@
 package org.springframework.security.web.authentication.logout;
 
 import org.junit.Before;
-import org.junit.Rule;
 import org.junit.Test;
-import org.junit.rules.ExpectedException;
 
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.web.header.HeaderWriter;
 
+import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
 
@@ -40,9 +39,6 @@ public class HeaderWriterLogoutHandlerTests {
 
 	private MockHttpServletRequest request;
 
-	@Rule
-	public ExpectedException thrown = ExpectedException.none();
-
 	@Before
 	public void setup() {
 		this.response = new MockHttpServletResponse();
@@ -51,9 +47,8 @@ public class HeaderWriterLogoutHandlerTests {
 
 	@Test
 	public void constructorWhenHeaderWriterIsNullThenThrowsException() {
-		this.thrown.expect(IllegalArgumentException.class);
-		this.thrown.expectMessage("headerWriter cannot be null");
-		new HeaderWriterLogoutHandler(null);
+		assertThatIllegalArgumentException().isThrownBy(() -> new HeaderWriterLogoutHandler(null))
+				.withMessage("headerWriter cannot be null");
 	}
 
 	@Test

+ 3 - 8
web/src/test/java/org/springframework/security/web/authentication/switchuser/SwitchUserFilterTests.java

@@ -23,9 +23,7 @@ import javax.servlet.FilterChain;
 
 import org.junit.After;
 import org.junit.Before;
-import org.junit.Rule;
 import org.junit.Test;
-import org.junit.rules.ExpectedException;
 
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
@@ -49,6 +47,7 @@ import org.springframework.security.web.authentication.SimpleUrlAuthenticationSu
 import org.springframework.security.web.util.matcher.AnyRequestMatcher;
 
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.verify;
@@ -64,9 +63,6 @@ public class SwitchUserFilterTests {
 
 	private static final List<GrantedAuthority> ROLES_12 = AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO");
 
-	@Rule
-	public ExpectedException thrown = ExpectedException.none();
-
 	@Before
 	public void authenticateCurrentUser() {
 		UsernamePasswordAuthenticationToken auth = new UsernamePasswordAuthenticationToken("dano", "hawaii50");
@@ -437,9 +433,8 @@ public class SwitchUserFilterTests {
 	// gh-3697
 	@Test
 	public void switchAuthorityRoleCannotBeNull() {
-		this.thrown.expect(IllegalArgumentException.class);
-		this.thrown.expectMessage("switchAuthorityRole cannot be null");
-		switchToUserWithAuthorityRole("dano", null);
+		assertThatIllegalArgumentException().isThrownBy(() -> switchToUserWithAuthorityRole("dano", null))
+				.withMessage("switchAuthorityRole cannot be null");
 	}
 
 	// gh-3697

+ 19 - 25
web/src/test/java/org/springframework/security/web/firewall/FirewalledResponseTests.java

@@ -20,9 +20,7 @@ import javax.servlet.http.Cookie;
 import javax.servlet.http.HttpServletResponse;
 
 import org.junit.Before;
-import org.junit.Rule;
 import org.junit.Test;
-import org.junit.rules.ExpectedException;
 
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.mockito.Mockito.mock;
@@ -35,8 +33,7 @@ import static org.mockito.Mockito.verify;
  */
 public class FirewalledResponseTests {
 
-	@Rule
-	public ExpectedException expectedException = ExpectedException.none();
+	private static final String CRLF_MESSAGE = "Invalid characters (CR/LF)";
 
 	private HttpServletResponse response;
 
@@ -62,8 +59,8 @@ public class FirewalledResponseTests {
 
 	@Test
 	public void sendRedirectWhenHasCrlfThenThrowsException() throws Exception {
-		expectCrlfValidationException();
-		this.fwResponse.sendRedirect("/theURL\r\nsomething");
+		assertThatIllegalArgumentException().isThrownBy(() -> this.fwResponse.sendRedirect("/theURL\r\nsomething"))
+				.withMessageContaining(CRLF_MESSAGE);
 	}
 
 	@Test
@@ -80,14 +77,16 @@ public class FirewalledResponseTests {
 
 	@Test
 	public void addHeaderWhenHeaderValueHasCrlfThenException() {
-		expectCrlfValidationException();
-		this.fwResponse.addHeader("foo", "abc\r\nContent-Length:100");
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> this.fwResponse.addHeader("foo", "abc\r\nContent-Length:100"))
+				.withMessageContaining(CRLF_MESSAGE);
 	}
 
 	@Test
 	public void addHeaderWhenHeaderNameHasCrlfThenException() {
-		expectCrlfValidationException();
-		this.fwResponse.addHeader("abc\r\nContent-Length:100", "bar");
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> this.fwResponse.addHeader("abc\r\nContent-Length:100", "bar"))
+				.withMessageContaining(CRLF_MESSAGE);
 	}
 
 	@Test
@@ -115,39 +114,39 @@ public class FirewalledResponseTests {
 				return "foo\r\nbar";
 			}
 		};
-		expectCrlfValidationException();
-		this.fwResponse.addCookie(cookie);
+		assertThatIllegalArgumentException().isThrownBy(() -> this.fwResponse.addCookie(cookie))
+				.withMessageContaining(CRLF_MESSAGE);
 	}
 
 	@Test
 	public void addCookieWhenCookieValueContainsCrlfThenException() {
 		Cookie cookie = new Cookie("foo", "foo\r\nbar");
-		expectCrlfValidationException();
-		this.fwResponse.addCookie(cookie);
+		assertThatIllegalArgumentException().isThrownBy(() -> this.fwResponse.addCookie(cookie))
+				.withMessageContaining(CRLF_MESSAGE);
 	}
 
 	@Test
 	public void addCookieWhenCookiePathContainsCrlfThenException() {
 		Cookie cookie = new Cookie("foo", "bar");
 		cookie.setPath("/foo\r\nbar");
-		expectCrlfValidationException();
-		this.fwResponse.addCookie(cookie);
+		assertThatIllegalArgumentException().isThrownBy(() -> this.fwResponse.addCookie(cookie))
+				.withMessageContaining(CRLF_MESSAGE);
 	}
 
 	@Test
 	public void addCookieWhenCookieDomainContainsCrlfThenException() {
 		Cookie cookie = new Cookie("foo", "bar");
 		cookie.setDomain("foo\r\nbar");
-		expectCrlfValidationException();
-		this.fwResponse.addCookie(cookie);
+		assertThatIllegalArgumentException().isThrownBy(() -> this.fwResponse.addCookie(cookie))
+				.withMessageContaining(CRLF_MESSAGE);
 	}
 
 	@Test
 	public void addCookieWhenCookieCommentContainsCrlfThenException() {
 		Cookie cookie = new Cookie("foo", "bar");
 		cookie.setComment("foo\r\nbar");
-		expectCrlfValidationException();
-		this.fwResponse.addCookie(cookie);
+		assertThatIllegalArgumentException().isThrownBy(() -> this.fwResponse.addCookie(cookie))
+				.withMessageContaining(CRLF_MESSAGE);
 	}
 
 	@Test
@@ -160,11 +159,6 @@ public class FirewalledResponseTests {
 		validateLineEnding("foo\nbar", "bar");
 	}
 
-	private void expectCrlfValidationException() {
-		this.expectedException.expect(IllegalArgumentException.class);
-		this.expectedException.expectMessage("Invalid characters (CR/LF)");
-	}
-
 	private void validateLineEnding(String name, String value) {
 		assertThatIllegalArgumentException().isThrownBy(() -> this.fwResponse.validateCrlf(name, value));
 	}

+ 3 - 8
web/src/test/java/org/springframework/security/web/header/writers/ClearSiteDataHeaderWriterTests.java

@@ -17,15 +17,14 @@
 package org.springframework.security.web.header.writers;
 
 import org.junit.Before;
-import org.junit.Rule;
 import org.junit.Test;
-import org.junit.rules.ExpectedException;
 
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.security.web.header.writers.ClearSiteDataHeaderWriter.Directive;
 
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 
 /**
  * @author Rafiullah Hamedy
@@ -40,9 +39,6 @@ public class ClearSiteDataHeaderWriterTests {
 
 	private MockHttpServletResponse response;
 
-	@Rule
-	public ExpectedException thrown = ExpectedException.none();
-
 	@Before
 	public void setup() {
 		this.request = new MockHttpServletRequest();
@@ -52,9 +48,8 @@ public class ClearSiteDataHeaderWriterTests {
 
 	@Test
 	public void createInstanceWhenMissingSourceThenThrowsException() {
-		this.thrown.expect(Exception.class);
-		this.thrown.expectMessage("directives cannot be empty or null");
-		new ClearSiteDataHeaderWriter();
+		assertThatExceptionOfType(Exception.class).isThrownBy(() -> new ClearSiteDataHeaderWriter())
+				.withMessage("directives cannot be empty or null");
 	}
 
 	@Test

+ 51 - 63
web/src/test/java/org/springframework/security/web/server/authentication/SwitchUserWebFilterTests.java

@@ -20,9 +20,7 @@ import java.security.Principal;
 import java.util.Collections;
 
 import org.junit.Before;
-import org.junit.Rule;
 import org.junit.Test;
-import org.junit.rules.ExpectedException;
 import org.junit.runner.RunWith;
 import org.mockito.ArgumentCaptor;
 import org.mockito.Mock;
@@ -55,7 +53,8 @@ import org.springframework.test.util.ReflectionTestUtils;
 import org.springframework.web.server.WebFilterChain;
 
 import static org.assertj.core.api.Assertions.assertThat;
-import static org.assertj.core.api.Assertions.fail;
+import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
+import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.BDDMockito.given;
@@ -83,9 +82,6 @@ public class SwitchUserWebFilterTests {
 	@Mock
 	private ServerSecurityContextRepository serverSecurityContextRepository;
 
-	@Rule
-	public ExpectedException exceptionRule = ExpectedException.none();
-
 	@Before
 	public void setUp() {
 		this.switchUserWebFilter = new SwitchUserWebFilter(this.userDetailsService, this.successHandler,
@@ -183,11 +179,12 @@ public class SwitchUserWebFilterTests {
 				.from(MockServerHttpRequest.post("/login/impersonate"));
 		final WebFilterChain chain = mock(WebFilterChain.class);
 		final SecurityContextImpl securityContext = new SecurityContextImpl(mock(Authentication.class));
-		this.exceptionRule.expect(IllegalArgumentException.class);
-		this.exceptionRule.expectMessage("The userName can not be null.");
-		this.switchUserWebFilter.filter(exchange, chain)
-				.subscriberContext(ReactiveSecurityContextHolder.withSecurityContext(Mono.just(securityContext)))
-				.block();
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> this.switchUserWebFilter.filter(exchange, chain)
+						.subscriberContext(
+								ReactiveSecurityContextHolder.withSecurityContext(Mono.just(securityContext)))
+						.block())
+				.withMessage("The userName can not be null.");
 		verifyNoInteractions(chain);
 	}
 
@@ -219,10 +216,12 @@ public class SwitchUserWebFilterTests {
 		final SecurityContextImpl securityContext = new SecurityContextImpl(mock(Authentication.class));
 		final UserDetails switchUserDetails = switchUserDetails(targetUsername, false);
 		given(this.userDetailsService.findByUsername(any(String.class))).willReturn(Mono.just(switchUserDetails));
-		this.exceptionRule.expect(DisabledException.class);
-		this.switchUserWebFilter.filter(exchange, chain)
-				.subscriberContext(ReactiveSecurityContextHolder.withSecurityContext(Mono.just(securityContext)))
-				.block();
+		assertThatExceptionOfType(DisabledException.class)
+				.isThrownBy(
+						() -> this.switchUserWebFilter.filter(exchange, chain)
+								.subscriberContext(
+										ReactiveSecurityContextHolder.withSecurityContext(Mono.just(securityContext)))
+								.block());
 		verifyNoInteractions(chain);
 	}
 
@@ -265,11 +264,12 @@ public class SwitchUserWebFilterTests {
 				"origCredentials");
 		final WebFilterChain chain = mock(WebFilterChain.class);
 		final SecurityContextImpl securityContext = new SecurityContextImpl(originalAuthentication);
-		this.exceptionRule.expect(AuthenticationCredentialsNotFoundException.class);
-		this.exceptionRule.expectMessage("Could not find original Authentication object");
-		this.switchUserWebFilter.filter(exchange, chain)
-				.subscriberContext(ReactiveSecurityContextHolder.withSecurityContext(Mono.just(securityContext)))
-				.block();
+		assertThatExceptionOfType(AuthenticationCredentialsNotFoundException.class)
+				.isThrownBy(() -> this.switchUserWebFilter.filter(exchange, chain)
+						.subscriberContext(
+								ReactiveSecurityContextHolder.withSecurityContext(Mono.just(securityContext)))
+						.block())
+				.withMessage("Could not find original Authentication object");
 		verifyNoInteractions(chain);
 	}
 
@@ -278,34 +278,35 @@ public class SwitchUserWebFilterTests {
 		final MockServerWebExchange exchange = MockServerWebExchange
 				.from(MockServerHttpRequest.post("/logout/impersonate"));
 		final WebFilterChain chain = mock(WebFilterChain.class);
-		this.exceptionRule.expect(AuthenticationCredentialsNotFoundException.class);
-		this.exceptionRule.expectMessage("No current user associated with this request");
-		this.switchUserWebFilter.filter(exchange, chain).block();
+		assertThatExceptionOfType(AuthenticationCredentialsNotFoundException.class)
+				.isThrownBy(() -> this.switchUserWebFilter.filter(exchange, chain).block())
+				.withMessage("No current user associated with this request");
 		verifyNoInteractions(chain);
 	}
 
 	@Test
 	public void constructorUserDetailsServiceRequired() {
-		this.exceptionRule.expect(IllegalArgumentException.class);
-		this.exceptionRule.expectMessage("userDetailsService must be specified");
-		this.switchUserWebFilter = new SwitchUserWebFilter(null, mock(ServerAuthenticationSuccessHandler.class),
-				mock(ServerAuthenticationFailureHandler.class));
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> this.switchUserWebFilter = new SwitchUserWebFilter(null,
+						mock(ServerAuthenticationSuccessHandler.class), mock(ServerAuthenticationFailureHandler.class)))
+				.withMessage("userDetailsService must be specified");
 	}
 
 	@Test
 	public void constructorServerAuthenticationSuccessHandlerRequired() {
-		this.exceptionRule.expect(IllegalArgumentException.class);
-		this.exceptionRule.expectMessage("successHandler must be specified");
-		this.switchUserWebFilter = new SwitchUserWebFilter(mock(ReactiveUserDetailsService.class), null,
-				mock(ServerAuthenticationFailureHandler.class));
+		assertThatIllegalArgumentException()
+				.isThrownBy(
+						() -> this.switchUserWebFilter = new SwitchUserWebFilter(mock(ReactiveUserDetailsService.class),
+								null, mock(ServerAuthenticationFailureHandler.class)))
+				.withMessage("successHandler must be specified");
 	}
 
 	@Test
 	public void constructorSuccessTargetUrlRequired() {
-		this.exceptionRule.expect(IllegalArgumentException.class);
-		this.exceptionRule.expectMessage("successTargetUrl must be specified");
-		this.switchUserWebFilter = new SwitchUserWebFilter(mock(ReactiveUserDetailsService.class), null,
-				"failure/target/url");
+		assertThatIllegalArgumentException().isThrownBy(
+				() -> this.switchUserWebFilter = new SwitchUserWebFilter(mock(ReactiveUserDetailsService.class), null,
+						"failure/target/url"))
+				.withMessage("successTargetUrl must be specified");
 	}
 
 	@Test
@@ -336,10 +337,9 @@ public class SwitchUserWebFilterTests {
 
 	@Test
 	public void setSecurityContextRepositoryWhenNullThenThrowException() {
-		this.exceptionRule.expect(IllegalArgumentException.class);
-		this.exceptionRule.expectMessage("securityContextRepository cannot be null");
-		this.switchUserWebFilter.setSecurityContextRepository(null);
-		fail("Test should fail with exception");
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> this.switchUserWebFilter.setSecurityContextRepository(null))
+				.withMessage("securityContextRepository cannot be null");
 	}
 
 	@Test
@@ -357,18 +357,14 @@ public class SwitchUserWebFilterTests {
 
 	@Test
 	public void setExitUserUrlWhenNullThenThrowException() {
-		this.exceptionRule.expect(IllegalArgumentException.class);
-		this.exceptionRule.expectMessage("exitUserUrl cannot be empty and must be a valid redirect URL");
-		this.switchUserWebFilter.setExitUserUrl(null);
-		fail("Test should fail with exception");
+		assertThatIllegalArgumentException().isThrownBy(() -> this.switchUserWebFilter.setExitUserUrl(null))
+				.withMessage("exitUserUrl cannot be empty and must be a valid redirect URL");
 	}
 
 	@Test
 	public void setExitUserUrlWhenInvalidUrlThenThrowException() {
-		this.exceptionRule.expect(IllegalArgumentException.class);
-		this.exceptionRule.expectMessage("exitUserUrl cannot be empty and must be a valid redirect URL");
-		this.switchUserWebFilter.setExitUserUrl("wrongUrl");
-		fail("Test should fail with exception");
+		assertThatIllegalArgumentException().isThrownBy(() -> this.switchUserWebFilter.setExitUserUrl("wrongUrl"))
+				.withMessage("exitUserUrl cannot be empty and must be a valid redirect URL");
 	}
 
 	@Test
@@ -387,10 +383,8 @@ public class SwitchUserWebFilterTests {
 
 	@Test
 	public void setExitUserMatcherWhenNullThenThrowException() {
-		this.exceptionRule.expect(IllegalArgumentException.class);
-		this.exceptionRule.expectMessage("exitUserMatcher cannot be null");
-		this.switchUserWebFilter.setExitUserMatcher(null);
-		fail("Test should fail with exception");
+		assertThatIllegalArgumentException().isThrownBy(() -> this.switchUserWebFilter.setExitUserMatcher(null))
+				.withMessage("exitUserMatcher cannot be null");
 	}
 
 	@Test
@@ -410,18 +404,14 @@ public class SwitchUserWebFilterTests {
 
 	@Test
 	public void setSwitchUserUrlWhenNullThenThrowException() {
-		this.exceptionRule.expect(IllegalArgumentException.class);
-		this.exceptionRule.expectMessage("switchUserUrl cannot be empty and must be a valid redirect URL");
-		this.switchUserWebFilter.setSwitchUserUrl(null);
-		fail("Test should fail with exception");
+		assertThatIllegalArgumentException().isThrownBy(() -> this.switchUserWebFilter.setSwitchUserUrl(null))
+				.withMessage("switchUserUrl cannot be empty and must be a valid redirect URL");
 	}
 
 	@Test
 	public void setSwitchUserUrlWhenInvalidThenThrowException() {
-		this.exceptionRule.expect(IllegalArgumentException.class);
-		this.exceptionRule.expectMessage("switchUserUrl cannot be empty and must be a valid redirect URL");
-		this.switchUserWebFilter.setSwitchUserUrl("wrongUrl");
-		fail("Test should fail with exception");
+		assertThatIllegalArgumentException().isThrownBy(() -> this.switchUserWebFilter.setSwitchUserUrl("wrongUrl"))
+				.withMessage("switchUserUrl cannot be empty and must be a valid redirect URL");
 	}
 
 	@Test
@@ -440,10 +430,8 @@ public class SwitchUserWebFilterTests {
 
 	@Test
 	public void setSwitchUserMatcherWhenNullThenThrowException() {
-		this.exceptionRule.expect(IllegalArgumentException.class);
-		this.exceptionRule.expectMessage("switchUserMatcher cannot be null");
-		this.switchUserWebFilter.setSwitchUserMatcher(null);
-		fail("Test should fail with exception");
+		assertThatIllegalArgumentException().isThrownBy(() -> this.switchUserWebFilter.setSwitchUserMatcher(null))
+				.withMessage("switchUserMatcher cannot be null");
 	}
 
 	@Test