Selaa lähdekoodia

Extract OidcTokenValidator to an OAuth2TokenValidator

Fixes gh-5930
Joe Grandja 6 vuotta sitten
vanhempi
commit
9c0d78da71

+ 10 - 9
config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java

@@ -630,15 +630,16 @@ public class OAuth2LoginConfigurerTests {
 		}
 
 		private static JwtDecoder getJwtDecoder() {
-			return token -> {
-				Map<String, Object> claims = new HashMap<>();
-				claims.put(IdTokenClaimNames.SUB, "sub123");
-				claims.put(IdTokenClaimNames.ISS, "http://localhost/iss");
-				claims.put(IdTokenClaimNames.AUD, Arrays.asList("clientId", "a", "u", "d"));
-				claims.put(IdTokenClaimNames.AZP, "clientId");
-				return new Jwt("token123", Instant.now(), Instant.now().plusSeconds(3600),
-						Collections.singletonMap("header1", "value1"), claims);
-			};
+			Map<String, Object> claims = new HashMap<>();
+			claims.put(IdTokenClaimNames.SUB, "sub123");
+			claims.put(IdTokenClaimNames.ISS, "http://localhost/iss");
+			claims.put(IdTokenClaimNames.AUD, Arrays.asList("clientId", "a", "u", "d"));
+			claims.put(IdTokenClaimNames.AZP, "clientId");
+			Jwt jwt = new Jwt("token123", Instant.now(), Instant.now().plusSeconds(3600),
+					Collections.singletonMap("header1", "value1"), claims);
+			JwtDecoder jwtDecoder = mock(JwtDecoder.class);
+			when(jwtDecoder.decode(any())).thenReturn(jwt);
+			return jwtDecoder;
 		}
 	}
 

+ 16 - 3
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProvider.java

@@ -27,9 +27,11 @@ import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest;
 import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserService;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.userinfo.OAuth2UserService;
+import org.springframework.security.oauth2.core.DelegatingOAuth2TokenValidator;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
+import org.springframework.security.oauth2.core.OAuth2TokenValidator;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
@@ -40,6 +42,8 @@ import org.springframework.security.oauth2.core.oidc.user.OidcUser;
 import org.springframework.security.oauth2.jwt.Jwt;
 import org.springframework.security.oauth2.jwt.JwtDecoder;
 import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
+import org.springframework.security.oauth2.jwt.JwtException;
+import org.springframework.security.oauth2.jwt.JwtTimestampValidator;
 import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;
 import org.springframework.util.Assert;
 import org.springframework.util.StringUtils;
@@ -205,9 +209,14 @@ public class OidcAuthorizationCodeAuthenticationProvider implements Authenticati
 
 	private OidcIdToken createOidcToken(ClientRegistration clientRegistration, OAuth2AccessTokenResponse accessTokenResponse) {
 		JwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(clientRegistration);
-		Jwt jwt = jwtDecoder.decode((String) accessTokenResponse.getAdditionalParameters().get(OidcParameterNames.ID_TOKEN));
+		Jwt jwt;
+		try {
+			jwt = jwtDecoder.decode((String) accessTokenResponse.getAdditionalParameters().get(OidcParameterNames.ID_TOKEN));
+		} catch (JwtException ex) {
+			OAuth2Error invalidIdTokenError = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, ex.getMessage(), null);
+			throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString(), ex);
+		}
 		OidcIdToken idToken = new OidcIdToken(jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaims());
-		OidcTokenValidator.validateIdToken(idToken, clientRegistration);
 		return idToken;
 	}
 
@@ -228,7 +237,11 @@ public class OidcAuthorizationCodeAuthenticationProvider implements Authenticati
 					throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
 				}
 				String jwkSetUri = clientRegistration.getProviderDetails().getJwkSetUri();
-				return new NimbusJwtDecoder(withJwkSetUri(jwkSetUri).build());
+				NimbusJwtDecoder jwtDecoder = new NimbusJwtDecoder(withJwkSetUri(jwkSetUri).build());
+				OAuth2TokenValidator<Jwt> jwtValidator = new DelegatingOAuth2TokenValidator<>(
+						new JwtTimestampValidator(), new OidcIdTokenValidator(clientRegistration));
+				jwtDecoder.setJwtValidator(jwtValidator);
+				return jwtDecoder;
 			});
 		}
 	}

+ 17 - 4
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManager.java

@@ -26,10 +26,12 @@ import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessT
 import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService;
+import org.springframework.security.oauth2.core.DelegatingOAuth2TokenValidator;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
+import org.springframework.security.oauth2.core.OAuth2TokenValidator;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
@@ -37,6 +39,9 @@ import org.springframework.security.oauth2.core.oidc.OidcIdToken;
 import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
 import org.springframework.security.oauth2.core.oidc.user.OidcUser;
 import org.springframework.security.oauth2.core.user.OAuth2User;
+import org.springframework.security.oauth2.jwt.Jwt;
+import org.springframework.security.oauth2.jwt.JwtException;
+import org.springframework.security.oauth2.jwt.JwtTimestampValidator;
 import org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder;
 import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder;
 import org.springframework.security.oauth2.jwt.ReactiveJwtDecoderFactory;
@@ -138,7 +143,11 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements
 
 			return this.accessTokenResponseClient.getTokenResponse(authzRequest)
 					.flatMap(accessTokenResponse -> authenticationResult(authorizationCodeAuthentication, accessTokenResponse))
-					.onErrorMap(OAuth2AuthorizationException.class, e -> new OAuth2AuthenticationException(e.getError(), e.getError().toString()));
+					.onErrorMap(OAuth2AuthorizationException.class, e -> new OAuth2AuthenticationException(e.getError(), e.getError().toString()))
+					.onErrorMap(JwtException.class, e -> {
+						OAuth2Error invalidIdTokenError = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE, e.getMessage(), null);
+						throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString(), e);
+					});
 		});
 	}
 
@@ -188,8 +197,7 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements
 		ReactiveJwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(clientRegistration);
 		String rawIdToken = (String) accessTokenResponse.getAdditionalParameters().get(OidcParameterNames.ID_TOKEN);
 		return jwtDecoder.decode(rawIdToken)
-				.map(jwt -> new OidcIdToken(jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaims()))
-				.doOnNext(idToken -> OidcTokenValidator.validateIdToken(idToken, clientRegistration));
+				.map(jwt -> new OidcIdToken(jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaims()));
 	}
 
 	private static class DefaultJwtDecoderFactory implements ReactiveJwtDecoderFactory<ClientRegistration> {
@@ -208,7 +216,12 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements
 					);
 					throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
 				}
-				return new NimbusReactiveJwtDecoder(clientRegistration.getProviderDetails().getJwkSetUri());
+				NimbusReactiveJwtDecoder jwtDecoder = new NimbusReactiveJwtDecoder(
+						clientRegistration.getProviderDetails().getJwkSetUri());
+				OAuth2TokenValidator<Jwt> jwtValidator = new DelegatingOAuth2TokenValidator<>(
+						new JwtTimestampValidator(), new OidcIdTokenValidator(clientRegistration));
+				jwtDecoder.setJwtValidator(jwtValidator);
+				return jwtDecoder;
 			});
 		}
 	}

+ 38 - 23
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcTokenValidator.java → oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenValidator.java

@@ -13,13 +13,16 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 package org.springframework.security.oauth2.client.oidc.authentication;
 
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
-import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
+import org.springframework.security.oauth2.core.OAuth2TokenValidator;
+import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult;
+import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames;
 import org.springframework.security.oauth2.core.oidc.OidcIdToken;
+import org.springframework.security.oauth2.jwt.Jwt;
+import org.springframework.util.Assert;
 import org.springframework.util.CollectionUtils;
 
 import java.net.URL;
@@ -27,36 +30,50 @@ import java.time.Instant;
 import java.util.List;
 
 /**
+ * An {@link OAuth2TokenValidator} responsible for
+ * validating the claims in an {@link OidcIdToken ID Token}.
+ *
  * @author Rob Winch
+ * @author Joe Grandja
  * @since 5.1
+ * @see OAuth2TokenValidator
+ * @see Jwt
+ * @see <a target="_blank" href="http://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation">ID Token Validation</a>
  */
-final class OidcTokenValidator {
-	private static final String INVALID_ID_TOKEN_ERROR_CODE = "invalid_id_token";
+public final class OidcIdTokenValidator implements OAuth2TokenValidator<Jwt> {
+	private static final OAuth2Error INVALID_ID_TOKEN_ERROR = new OAuth2Error("invalid_id_token");
+	private final ClientRegistration clientRegistration;
 
-	static void validateIdToken(OidcIdToken idToken, ClientRegistration clientRegistration) {
+	public OidcIdTokenValidator(ClientRegistration clientRegistration) {
+		Assert.notNull(clientRegistration, "clientRegistration cannot be null");
+		this.clientRegistration = clientRegistration;
+	}
+
+	@Override
+	public OAuth2TokenValidatorResult validate(Jwt idToken) {
 		// 3.1.3.7  ID Token Validation
 		// http://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
 
 		// Validate REQUIRED Claims
 		URL issuer = idToken.getIssuer();
 		if (issuer == null) {
-			throwInvalidIdTokenException();
+			return invalidIdToken();
 		}
 		String subject = idToken.getSubject();
 		if (subject == null) {
-			throwInvalidIdTokenException();
+			return invalidIdToken();
 		}
 		List<String> audience = idToken.getAudience();
 		if (CollectionUtils.isEmpty(audience)) {
-			throwInvalidIdTokenException();
+			return invalidIdToken();
 		}
 		Instant expiresAt = idToken.getExpiresAt();
 		if (expiresAt == null) {
-			throwInvalidIdTokenException();
+			return invalidIdToken();
 		}
 		Instant issuedAt = idToken.getIssuedAt();
 		if (issuedAt == null) {
-			throwInvalidIdTokenException();
+			return invalidIdToken();
 		}
 
 		// 2. The Issuer Identifier for the OpenID Provider (which is typically obtained during Discovery)
@@ -68,21 +85,21 @@ final class OidcTokenValidator {
 		// The aud (audience) Claim MAY contain an array with more than one element.
 		// The ID Token MUST be rejected if the ID Token does not list the Client as a valid audience,
 		// or if it contains additional audiences not trusted by the Client.
-		if (!audience.contains(clientRegistration.getClientId())) {
-			throwInvalidIdTokenException();
+		if (!audience.contains(this.clientRegistration.getClientId())) {
+			return invalidIdToken();
 		}
 
 		// 4. If the ID Token contains multiple audiences,
 		// the Client SHOULD verify that an azp Claim is present.
-		String authorizedParty = idToken.getAuthorizedParty();
+		String authorizedParty = idToken.getClaimAsString(IdTokenClaimNames.AZP);
 		if (audience.size() > 1 && authorizedParty == null) {
-			throwInvalidIdTokenException();
+			return invalidIdToken();
 		}
 
 		// 5. If an azp (authorized party) Claim is present,
 		// the Client SHOULD verify that its client_id is the Claim Value.
-		if (authorizedParty != null && !authorizedParty.equals(clientRegistration.getClientId())) {
-			throwInvalidIdTokenException();
+		if (authorizedParty != null && !authorizedParty.equals(this.clientRegistration.getClientId())) {
+			return invalidIdToken();
 		}
 
 		// 7. The alg value SHOULD be the default of RS256 or the algorithm sent by the Client
@@ -92,7 +109,7 @@ final class OidcTokenValidator {
 		// 9. The current time MUST be before the time represented by the exp Claim.
 		Instant now = Instant.now();
 		if (!now.isBefore(expiresAt)) {
-			throwInvalidIdTokenException();
+			return invalidIdToken();
 		}
 
 		// 10. The iat Claim can be used to reject tokens that were issued too far away from the current time,
@@ -100,7 +117,7 @@ final class OidcTokenValidator {
 		// The acceptable range is Client specific.
 		Instant maxIssuedAt = now.plusSeconds(30);
 		if (issuedAt.isAfter(maxIssuedAt)) {
-			throwInvalidIdTokenException();
+			return invalidIdToken();
 		}
 
 		// 11. If a nonce value was sent in the Authentication Request,
@@ -110,12 +127,10 @@ final class OidcTokenValidator {
 		// The precise method for detecting replay attacks is Client specific.
 		// TODO Depends on gh-4442
 
+		return OAuth2TokenValidatorResult.success();
 	}
 
-	private static void throwInvalidIdTokenException() {
-		OAuth2Error invalidIdTokenError = new OAuth2Error(INVALID_ID_TOKEN_ERROR_CODE);
-		throw new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString());
+	private static OAuth2TokenValidatorResult invalidIdToken() {
+		return OAuth2TokenValidatorResult.failure(INVALID_ID_TOKEN_ERROR);
 	}
-
-	private OidcTokenValidator() {}
 }

+ 13 - 131
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProviderTests.java

@@ -42,6 +42,7 @@ import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames
 import org.springframework.security.oauth2.core.oidc.user.OidcUser;
 import org.springframework.security.oauth2.jwt.Jwt;
 import org.springframework.security.oauth2.jwt.JwtDecoder;
+import org.springframework.security.oauth2.jwt.JwtException;
 
 import java.time.Instant;
 import java.util.Arrays;
@@ -82,7 +83,7 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
 
 	@Before
 	@SuppressWarnings("unchecked")
-	public void setUp() throws Exception {
+	public void setUp() {
 		this.clientRegistration = clientRegistration().clientId("client1").build();
 		this.authorizationRequest = request().scope("openid", "profile", "email").build();
 		this.authorizationResponse = success().build();
@@ -204,139 +205,20 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
 	}
 
 	@Test
-	public void authenticateWhenIdTokenIssuerClaimIsNullThenThrowOAuth2AuthenticationException() throws Exception {
+	public void authenticateWhenIdTokenValidationErrorThenThrowOAuth2AuthenticationException() {
 		this.exception.expect(OAuth2AuthenticationException.class);
-		this.exception.expectMessage(containsString("invalid_id_token"));
-
-		Map<String, Object> claims = new HashMap<>();
-		claims.put(IdTokenClaimNames.SUB, "subject1");
-
-		this.setUpIdToken(claims);
-
-		this.authenticationProvider.authenticate(
-			new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
-	}
-
-	@Test
-	public void authenticateWhenIdTokenSubjectClaimIsNullThenThrowOAuth2AuthenticationException() throws Exception {
-		this.exception.expect(OAuth2AuthenticationException.class);
-		this.exception.expectMessage(containsString("invalid_id_token"));
-
-		Map<String, Object> claims = new HashMap<>();
-		claims.put(IdTokenClaimNames.ISS, "https://provider.com");
-
-		this.setUpIdToken(claims);
-
-		this.authenticationProvider.authenticate(
-			new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
-	}
-
-	@Test
-	public void authenticateWhenIdTokenAudienceClaimIsNullThenThrowOAuth2AuthenticationException() throws Exception {
-		this.exception.expect(OAuth2AuthenticationException.class);
-		this.exception.expectMessage(containsString("invalid_id_token"));
-
-		Map<String, Object> claims = new HashMap<>();
-		claims.put(IdTokenClaimNames.ISS, "https://provider.com");
-		claims.put(IdTokenClaimNames.SUB, "subject1");
-
-		this.setUpIdToken(claims);
-
-		this.authenticationProvider.authenticate(
-			new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
-	}
-
-	@Test
-	public void authenticateWhenIdTokenAudienceClaimDoesNotContainClientIdThenThrowOAuth2AuthenticationException() throws Exception {
-		this.exception.expect(OAuth2AuthenticationException.class);
-		this.exception.expectMessage(containsString("invalid_id_token"));
-
-		Map<String, Object> claims = new HashMap<>();
-		claims.put(IdTokenClaimNames.ISS, "https://provider.com");
-		claims.put(IdTokenClaimNames.SUB, "subject1");
-		claims.put(IdTokenClaimNames.AUD, Collections.singletonList("other-client"));
-
-		this.setUpIdToken(claims);
-
-		this.authenticationProvider.authenticate(
-			new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
-	}
-
-	@Test
-	public void authenticateWhenIdTokenMultipleAudienceClaimAndAuthorizedPartyClaimIsNullThenThrowOAuth2AuthenticationException() throws Exception {
-		this.exception.expect(OAuth2AuthenticationException.class);
-		this.exception.expectMessage(containsString("invalid_id_token"));
-
-		Map<String, Object> claims = new HashMap<>();
-		claims.put(IdTokenClaimNames.ISS, "https://provider.com");
-		claims.put(IdTokenClaimNames.SUB, "subject1");
-		claims.put(IdTokenClaimNames.AUD, Arrays.asList("client1", "client2"));
-
-		this.setUpIdToken(claims);
+		this.exception.expectMessage(containsString("[invalid_id_token] ID Token Validation Error"));
 
-		this.authenticationProvider.authenticate(
-			new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
-	}
-
-	@Test
-	public void authenticateWhenIdTokenAuthorizedPartyClaimNotEqualToClientIdThenThrowOAuth2AuthenticationException() throws Exception {
-		this.exception.expect(OAuth2AuthenticationException.class);
-		this.exception.expectMessage(containsString("invalid_id_token"));
-
-		Map<String, Object> claims = new HashMap<>();
-		claims.put(IdTokenClaimNames.ISS, "https://provider.com");
-		claims.put(IdTokenClaimNames.SUB, "subject1");
-		claims.put(IdTokenClaimNames.AUD, Arrays.asList("client1", "client2"));
-		claims.put(IdTokenClaimNames.AZP, "other-client");
-
-		this.setUpIdToken(claims);
-
-		this.authenticationProvider.authenticate(
-			new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
-	}
-
-	@Test
-	public void authenticateWhenIdTokenExpiresAtIsBeforeNowThenThrowOAuth2AuthenticationException() throws Exception {
-		this.exception.expect(OAuth2AuthenticationException.class);
-		this.exception.expectMessage(containsString("invalid_id_token"));
-
-		Map<String, Object> claims = new HashMap<>();
-		claims.put(IdTokenClaimNames.ISS, "https://provider.com");
-		claims.put(IdTokenClaimNames.SUB, "subject1");
-		claims.put(IdTokenClaimNames.AUD, Arrays.asList("client1", "client2"));
-		claims.put(IdTokenClaimNames.AZP, "client1");
-
-		Instant issuedAt = Instant.now().minusSeconds(10);
-		Instant expiresAt = Instant.from(issuedAt).plusSeconds(5);
-
-		this.setUpIdToken(claims, issuedAt, expiresAt);
-
-		this.authenticationProvider.authenticate(
-			new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
-	}
-
-	@Test
-	public void authenticateWhenIdTokenIssuedAtIsAfterMaxIssuedAtThenThrowOAuth2AuthenticationException() throws Exception {
-		this.exception.expect(OAuth2AuthenticationException.class);
-		this.exception.expectMessage(containsString("invalid_id_token"));
-
-		Map<String, Object> claims = new HashMap<>();
-		claims.put(IdTokenClaimNames.ISS, "https://provider.com");
-		claims.put(IdTokenClaimNames.SUB, "subject1");
-		claims.put(IdTokenClaimNames.AUD, Arrays.asList("client1", "client2"));
-		claims.put(IdTokenClaimNames.AZP, "client1");
-
-		Instant issuedAt = Instant.now().plusSeconds(35);
-		Instant expiresAt = Instant.from(issuedAt).plusSeconds(60);
-
-		this.setUpIdToken(claims, issuedAt, expiresAt);
+		JwtDecoder jwtDecoder = mock(JwtDecoder.class);
+		when(jwtDecoder.decode(anyString())).thenThrow(new JwtException("ID Token Validation Error"));
+		this.authenticationProvider.setJwtDecoderFactory(registration -> jwtDecoder);
 
 		this.authenticationProvider.authenticate(
-			new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
+				new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
 	}
 
 	@Test
-	public void authenticateWhenLoginSuccessThenReturnAuthentication() throws Exception {
+	public void authenticateWhenLoginSuccessThenReturnAuthentication() {
 		Map<String, Object> claims = new HashMap<>();
 		claims.put(IdTokenClaimNames.ISS, "https://provider.com");
 		claims.put(IdTokenClaimNames.SUB, "subject1");
@@ -365,7 +247,7 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
 	}
 
 	@Test
-	public void authenticateWhenAuthoritiesMapperSetThenReturnMappedAuthorities() throws Exception {
+	public void authenticateWhenAuthoritiesMapperSetThenReturnMappedAuthorities() {
 		Map<String, Object> claims = new HashMap<>();
 		claims.put(IdTokenClaimNames.ISS, "https://provider.com");
 		claims.put(IdTokenClaimNames.SUB, "subject1");
@@ -394,7 +276,7 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
 
 	// gh-5368
 	@Test
-	public void authenticateWhenTokenSuccessResponseThenAdditionalParametersAddedToUserRequest() throws Exception {
+	public void authenticateWhenTokenSuccessResponseThenAdditionalParametersAddedToUserRequest() {
 		Map<String, Object> claims = new HashMap<>();
 		claims.put(IdTokenClaimNames.ISS, "https://provider.com");
 		claims.put(IdTokenClaimNames.SUB, "subject1");
@@ -416,13 +298,13 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
 				this.accessTokenResponse.getAdditionalParameters());
 	}
 
-	private void setUpIdToken(Map<String, Object> claims) throws Exception {
+	private void setUpIdToken(Map<String, Object> claims) {
 		Instant issuedAt = Instant.now();
 		Instant expiresAt = Instant.from(issuedAt).plusSeconds(3600);
 		this.setUpIdToken(claims, issuedAt, expiresAt);
 	}
 
-	private void setUpIdToken(Map<String, Object> claims, Instant issuedAt, Instant expiresAt) throws Exception {
+	private void setUpIdToken(Map<String, Object> claims, Instant issuedAt, Instant expiresAt) {
 		Map<String, Object> headers = new HashMap<>();
 		headers.put("alg", "RS256");
 

+ 17 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeReactiveAuthenticationManagerTests.java

@@ -44,6 +44,7 @@ import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames
 import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
 import org.springframework.security.oauth2.core.oidc.user.OidcUser;
 import org.springframework.security.oauth2.jwt.Jwt;
+import org.springframework.security.oauth2.jwt.JwtException;
 import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder;
 import reactor.core.publisher.Mono;
 
@@ -143,6 +144,22 @@ public class OidcAuthorizationCodeReactiveAuthenticationManagerTests {
 				.isInstanceOf(OAuth2AuthenticationException.class);
 	}
 
+	@Test
+	public void authenticateWhenIdTokenValidationErrorThenOAuth2AuthenticationException() {
+		OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo")
+				.tokenType(OAuth2AccessToken.TokenType.BEARER)
+				.additionalParameters(Collections.singletonMap(OidcParameterNames.ID_TOKEN, this.idToken.getTokenValue()))
+				.build();
+		when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse));
+
+		when(this.jwtDecoder.decode(any())).thenThrow(new JwtException("ID Token Validation Error"));
+		this.manager.setJwtDecoderFactory(c -> this.jwtDecoder);
+
+		assertThatThrownBy(() -> this.manager.authenticate(loginToken()).block())
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.hasMessageContaining("[invalid_id_token] ID Token Validation Error");
+	}
+
 	@Test
 	public void authenticationWhenOAuth2UserNotFoundThenEmpty() {
 		OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo")

+ 186 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenValidatorTests.java

@@ -0,0 +1,186 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client.oidc.authentication;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
+import org.springframework.security.oauth2.core.OAuth2Error;
+import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames;
+import org.springframework.security.oauth2.jose.jws.JwsAlgorithms;
+import org.springframework.security.oauth2.jwt.Jwt;
+
+import java.time.Duration;
+import java.time.Instant;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * @author Rob Winch
+ * @author Joe Grandja
+ * @since 5.1
+ */
+public class OidcIdTokenValidatorTests {
+	private ClientRegistration.Builder registration = TestClientRegistrations.clientRegistration();
+	private Map<String, Object> headers = new HashMap<>();
+	private Map<String, Object> claims = new HashMap<>();
+	private Instant issuedAt = Instant.now();
+	private Instant expiresAt = this.issuedAt.plusSeconds(3600);
+
+	@Before
+	public void setup() {
+		this.headers.put("alg", JwsAlgorithms.RS256);
+		this.claims.put(IdTokenClaimNames.ISS, "https://issuer.example.com");
+		this.claims.put(IdTokenClaimNames.SUB, "rob");
+		this.claims.put(IdTokenClaimNames.AUD, Collections.singletonList("client-id"));
+	}
+
+	@Test
+	public void validateIdTokenWhenValidThenNoErrors() {
+		assertThat(this.validateIdToken()).isEmpty();
+	}
+
+	@Test
+	public void validateIdTokenWhenIssuerNullThenHasErrors() {
+		this.claims.remove(IdTokenClaimNames.ISS);
+		assertThat(this.validateIdToken())
+				.hasSize(1)
+				.extracting(OAuth2Error::getErrorCode)
+				.contains("invalid_id_token");
+	}
+
+	@Test
+	public void validateIdTokenWhenSubNullThenHasErrors() {
+		this.claims.remove(IdTokenClaimNames.SUB);
+		assertThat(this.validateIdToken())
+				.hasSize(1)
+				.extracting(OAuth2Error::getErrorCode)
+				.contains("invalid_id_token");
+	}
+
+	@Test
+	public void validateIdTokenWhenAudNullThenHasErrors() {
+		this.claims.remove(IdTokenClaimNames.AUD);
+		assertThat(this.validateIdToken())
+				.hasSize(1)
+				.extracting(OAuth2Error::getErrorCode)
+				.contains("invalid_id_token");
+	}
+
+	@Test
+	public void validateIdTokenWhenIssuedAtNullThenHasErrors() {
+		this.issuedAt = null;
+		assertThat(this.validateIdToken())
+				.hasSize(1)
+				.extracting(OAuth2Error::getErrorCode)
+				.contains("invalid_id_token");
+	}
+
+	@Test
+	public void validateIdTokenWhenExpiresAtNullThenHasErrors() {
+		this.expiresAt = null;
+		assertThat(this.validateIdToken())
+				.hasSize(1)
+				.extracting(OAuth2Error::getErrorCode)
+				.contains("invalid_id_token");
+	}
+
+	@Test
+	public void validateIdTokenWhenAudMultipleAndAzpNullThenHasErrors() {
+		this.claims.put(IdTokenClaimNames.AUD, Arrays.asList("client-id", "other"));
+		assertThat(this.validateIdToken())
+				.hasSize(1)
+				.extracting(OAuth2Error::getErrorCode)
+				.contains("invalid_id_token");
+	}
+
+	@Test
+	public void validateIdTokenWhenAzpNotClientIdThenHasErrors() {
+		this.claims.put(IdTokenClaimNames.AZP, "other");
+		assertThat(this.validateIdToken())
+				.hasSize(1)
+				.extracting(OAuth2Error::getErrorCode)
+				.contains("invalid_id_token");
+	}
+
+	@Test
+	public void validateIdTokenWhenMultipleAudAzpClientIdThenNoErrors() {
+		this.claims.put(IdTokenClaimNames.AUD, Arrays.asList("client-id", "other"));
+		this.claims.put(IdTokenClaimNames.AZP, "client-id");
+		assertThat(this.validateIdToken()).isEmpty();
+	}
+
+	@Test
+	public void validateIdTokenWhenMultipleAudAzpNotClientIdThenHasErrors() {
+		this.claims.put(IdTokenClaimNames.AUD, Arrays.asList("client-id-1", "client-id-2"));
+		this.claims.put(IdTokenClaimNames.AZP, "other-client");
+		assertThat(this.validateIdToken())
+				.hasSize(1)
+				.extracting(OAuth2Error::getErrorCode)
+				.contains("invalid_id_token");
+	}
+
+	@Test
+	public void validateIdTokenWhenAudNotClientIdThenHasErrors() {
+		this.claims.put(IdTokenClaimNames.AUD, Collections.singletonList("other-client"));
+		assertThat(this.validateIdToken())
+				.hasSize(1)
+				.extracting(OAuth2Error::getErrorCode)
+				.contains("invalid_id_token");
+	}
+
+	@Test
+	public void validateIdTokenWhenExpiredThenHasErrors() {
+		this.issuedAt = Instant.now().minus(Duration.ofMinutes(1));
+		this.expiresAt = this.issuedAt.plus(Duration.ofSeconds(1));
+		assertThat(this.validateIdToken())
+				.hasSize(1)
+				.extracting(OAuth2Error::getErrorCode)
+				.contains("invalid_id_token");
+	}
+
+	@Test
+	public void validateIdTokenWhenIssuedAtWayInFutureThenHasErrors() {
+		this.issuedAt = Instant.now().plus(Duration.ofMinutes(5));
+		this.expiresAt = this.issuedAt.plus(Duration.ofSeconds(1));
+		assertThat(this.validateIdToken())
+				.hasSize(1)
+				.extracting(OAuth2Error::getErrorCode)
+				.contains("invalid_id_token");
+	}
+
+	@Test
+	public void validateIdTokenWhenExpiresAtBeforeNowThenHasErrors() {
+		this.issuedAt = Instant.now().minusSeconds(10);
+		this.expiresAt = Instant.from(this.issuedAt).plusSeconds(5);
+		assertThat(this.validateIdToken())
+				.hasSize(1)
+				.extracting(OAuth2Error::getErrorCode)
+				.contains("invalid_id_token");
+	}
+
+	private Collection<OAuth2Error> validateIdToken() {
+		Jwt idToken = new Jwt("token123", this.issuedAt, this.expiresAt, this.headers, this.claims);
+		OidcIdTokenValidator validator = new OidcIdTokenValidator(this.registration.build());
+		return validator.validate(idToken).getErrors();
+	}
+}

+ 0 - 137
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcTokenValidatorTests.java

@@ -1,137 +0,0 @@
-/*
- * Copyright 2002-2018 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.springframework.security.oauth2.client.oidc.authentication;
-
-import org.junit.Before;
-import org.junit.Test;
-import org.springframework.security.oauth2.client.registration.ClientRegistration;
-import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
-import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
-import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames;
-import org.springframework.security.oauth2.core.oidc.OidcIdToken;
-
-import java.time.Duration;
-import java.time.Instant;
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.Map;
-
-import static org.assertj.core.api.Assertions.assertThatCode;
-
-/**
- * @author Rob Winch
- * @since 5.1
- */
-public class OidcTokenValidatorTests {
-	private ClientRegistration.Builder registration = TestClientRegistrations.clientRegistration();
-
-	private Map<String, Object> claims = new HashMap<>();
-	private Instant issuedAt = Instant.now();
-	private Instant expiresAt = Instant.now().plusSeconds(3600);
-
-	@Before
-	public void setup() {
-		this.claims.put(IdTokenClaimNames.ISS, "https://issuer.example.com");
-		this.claims.put(IdTokenClaimNames.SUB, "rob");
-		this.claims.put(IdTokenClaimNames.AUD, Arrays.asList("client-id"));
-	}
-
-	@Test
-	public void validateIdTokenWhenValidThenNoException() {
-		assertThatCode(() -> validateIdToken())
-				.doesNotThrowAnyException();
-	}
-
-	@Test
-	public void validateIdTokenWhenIssuerNullThenException() {
-		this.claims.remove(IdTokenClaimNames.ISS);
-		assertThatCode(() -> validateIdToken())
-				.isInstanceOf(OAuth2AuthenticationException.class);
-	}
-
-	@Test
-	public void validateIdTokenWhenSubNullThenException() {
-		this.claims.remove(IdTokenClaimNames.SUB);
-		assertThatCode(() -> validateIdToken())
-				.isInstanceOf(OAuth2AuthenticationException.class);
-	}
-
-	@Test
-	public void validateIdTokenWhenAudNullThenException() {
-		this.claims.remove(IdTokenClaimNames.AUD);
-		assertThatCode(() -> validateIdToken())
-				.isInstanceOf(OAuth2AuthenticationException.class);
-	}
-
-	@Test
-	public void validateIdTokenWhenIssuedAtNullThenException() {
-		this.issuedAt = null;
-		assertThatCode(() -> validateIdToken())
-				.isInstanceOf(OAuth2AuthenticationException.class);
-	}
-
-	@Test
-	public void validateIdTokenWhenExpiresAtNullThenException() {
-		this.expiresAt = null;
-		assertThatCode(() -> validateIdToken())
-				.isInstanceOf(OAuth2AuthenticationException.class);
-	}
-
-	@Test
-	public void validateIdTokenWhenAudMultipleAndAzpNullThenException() {
-		this.claims.put(IdTokenClaimNames.AUD, Arrays.asList("client-id", "other"));
-		assertThatCode(() -> validateIdToken())
-				.isInstanceOf(OAuth2AuthenticationException.class);
-	}
-
-	@Test
-	public void validateIdTokenWhenAzpNotClientIdThenException() {
-		this.claims.put(IdTokenClaimNames.AZP, "other");
-		assertThatCode(() -> validateIdToken())
-				.isInstanceOf(OAuth2AuthenticationException.class);
-	}
-
-	@Test
-	public void validateIdTokenWhenMulitpleAudAzpClientIdThenNoException() {
-		this.claims.put(IdTokenClaimNames.AUD, Arrays.asList("client-id", "other"));
-		this.claims.put(IdTokenClaimNames.AZP, "client-id");
-		assertThatCode(() -> validateIdToken())
-				.doesNotThrowAnyException();
-	}
-
-	@Test
-	public void validateIdTokenWhenExpiredThenException() {
-		this.issuedAt = Instant.now().minus(Duration.ofMinutes(1));
-		this.expiresAt = this.issuedAt.plus(Duration.ofSeconds(1));
-		assertThatCode(() -> validateIdToken())
-				.isInstanceOf(OAuth2AuthenticationException.class);
-	}
-
-	@Test
-	public void validateIdTokenWhenIssuedAtWayInFutureThenException() {
-		this.issuedAt = Instant.now().plus(Duration.ofMinutes(5));
-		this.expiresAt = this.issuedAt.plus(Duration.ofSeconds(1));
-		assertThatCode(() -> validateIdToken())
-				.isInstanceOf(OAuth2AuthenticationException.class);
-	}
-
-	private void validateIdToken() {
-		OidcIdToken token = new OidcIdToken("token123", this.issuedAt, this.expiresAt, this.claims);
-		OidcTokenValidator.validateIdToken(token, this.registration.build());
-	}
-
-}