Sfoglia il codice sorgente

Improve error messages in OidcIdTokenValidator

This commit ensures that error messages contain more specific
information regarding the reported error.

Fixes: gh-6323
Rafael Dominguez 6 anni fa
parent
commit
057ed616c4

+ 53 - 33
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenValidator.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2019 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -27,7 +27,10 @@ import org.springframework.util.CollectionUtils;
 
 import java.net.URL;
 import java.time.Instant;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
 
 /**
  * An {@link OAuth2TokenValidator} responsible for
@@ -41,7 +44,6 @@ import java.util.List;
  * @see <a target="_blank" href="http://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation">ID Token Validation</a>
  */
 public final class OidcIdTokenValidator implements OAuth2TokenValidator<Jwt> {
-	private static final OAuth2Error INVALID_ID_TOKEN_ERROR = new OAuth2Error("invalid_id_token");
 	private final ClientRegistration clientRegistration;
 
 	public OidcIdTokenValidator(ClientRegistration clientRegistration) {
@@ -53,27 +55,10 @@ public final class OidcIdTokenValidator implements OAuth2TokenValidator<Jwt> {
 	public OAuth2TokenValidatorResult validate(Jwt idToken) {
 		// 3.1.3.7  ID Token Validation
 		// http://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
+		Map<String, Object> invalidClaims = validateRequiredClaims(idToken);
 
-		// Validate REQUIRED Claims
-		URL issuer = idToken.getIssuer();
-		if (issuer == null) {
-			return invalidIdToken();
-		}
-		String subject = idToken.getSubject();
-		if (subject == null) {
-			return invalidIdToken();
-		}
-		List<String> audience = idToken.getAudience();
-		if (CollectionUtils.isEmpty(audience)) {
-			return invalidIdToken();
-		}
-		Instant expiresAt = idToken.getExpiresAt();
-		if (expiresAt == null) {
-			return invalidIdToken();
-		}
-		Instant issuedAt = idToken.getIssuedAt();
-		if (issuedAt == null) {
-			return invalidIdToken();
+		if (!invalidClaims.isEmpty()){
+			return  OAuth2TokenValidatorResult.failure(invalidIdToken(invalidClaims));
 		}
 
 		// 2. The Issuer Identifier for the OpenID Provider (which is typically obtained during Discovery)
@@ -85,21 +70,21 @@ public final class OidcIdTokenValidator implements OAuth2TokenValidator<Jwt> {
 		// The aud (audience) Claim MAY contain an array with more than one element.
 		// The ID Token MUST be rejected if the ID Token does not list the Client as a valid audience,
 		// or if it contains additional audiences not trusted by the Client.
-		if (!audience.contains(this.clientRegistration.getClientId())) {
-			return invalidIdToken();
+		if (!idToken.getAudience().contains(this.clientRegistration.getClientId())) {
+			invalidClaims.put(IdTokenClaimNames.AUD, idToken.getAudience());
 		}
 
 		// 4. If the ID Token contains multiple audiences,
 		// the Client SHOULD verify that an azp Claim is present.
 		String authorizedParty = idToken.getClaimAsString(IdTokenClaimNames.AZP);
-		if (audience.size() > 1 && authorizedParty == null) {
-			return invalidIdToken();
+		if (idToken.getAudience().size() > 1 && authorizedParty == null) {
+			invalidClaims.put(IdTokenClaimNames.AZP, authorizedParty);
 		}
 
 		// 5. If an azp (authorized party) Claim is present,
 		// the Client SHOULD verify that its client_id is the Claim Value.
 		if (authorizedParty != null && !authorizedParty.equals(this.clientRegistration.getClientId())) {
-			return invalidIdToken();
+			invalidClaims.put(IdTokenClaimNames.AZP, authorizedParty);
 		}
 
 		// 7. The alg value SHOULD be the default of RS256 or the algorithm sent by the Client
@@ -108,16 +93,16 @@ public final class OidcIdTokenValidator implements OAuth2TokenValidator<Jwt> {
 
 		// 9. The current time MUST be before the time represented by the exp Claim.
 		Instant now = Instant.now();
-		if (!now.isBefore(expiresAt)) {
-			return invalidIdToken();
+		if (!now.isBefore(idToken.getExpiresAt())) {
+			invalidClaims.put(IdTokenClaimNames.EXP, idToken.getExpiresAt());
 		}
 
 		// 10. The iat Claim can be used to reject tokens that were issued too far away from the current time,
 		// limiting the amount of time that nonces need to be stored to prevent attacks.
 		// The acceptable range is Client specific.
 		Instant maxIssuedAt = now.plusSeconds(30);
-		if (issuedAt.isAfter(maxIssuedAt)) {
-			return invalidIdToken();
+		if (idToken.getIssuedAt().isAfter(maxIssuedAt)) {
+			invalidClaims.put(IdTokenClaimNames.IAT, idToken.getIssuedAt());
 		}
 
 		// 11. If a nonce value was sent in the Authentication Request,
@@ -127,10 +112,45 @@ public final class OidcIdTokenValidator implements OAuth2TokenValidator<Jwt> {
 		// The precise method for detecting replay attacks is Client specific.
 		// TODO Depends on gh-4442
 
+		if (!invalidClaims.isEmpty()) {
+			return OAuth2TokenValidatorResult.failure(invalidIdToken(invalidClaims));
+		}
+
 		return OAuth2TokenValidatorResult.success();
 	}
 
-	private static OAuth2TokenValidatorResult invalidIdToken() {
-		return OAuth2TokenValidatorResult.failure(INVALID_ID_TOKEN_ERROR);
+	private static OAuth2Error invalidIdToken(Map<String, Object> invalidClaims) {
+		String claimsDetail = invalidClaims.entrySet().stream()
+				.map(it -> it.getKey()+ "("+it.getValue()+")")
+				.collect(Collectors.joining(", "));
+
+		return new OAuth2Error("invalid_id_token", "The ID Token contains invalid claims: "+claimsDetail, null);
+	}
+
+	private static Map<String, Object>  validateRequiredClaims(Jwt idToken){
+		Map<String, Object> requiredClaims = new HashMap<>();
+
+		URL issuer = idToken.getIssuer();
+		if (issuer == null) {
+			requiredClaims.put(IdTokenClaimNames.ISS, issuer);
+		}
+		String subject = idToken.getSubject();
+		if (subject == null) {
+			requiredClaims.put(IdTokenClaimNames.SUB, subject);
+		}
+		List<String> audience = idToken.getAudience();
+		if (CollectionUtils.isEmpty(audience)) {
+			requiredClaims.put(IdTokenClaimNames.AUD, audience);
+		}
+		Instant expiresAt = idToken.getExpiresAt();
+		if (expiresAt == null) {
+			requiredClaims.put(IdTokenClaimNames.EXP, expiresAt);
+		}
+		Instant issuedAt = idToken.getIssuedAt();
+		if (issuedAt == null) {
+			requiredClaims.put(IdTokenClaimNames.IAT, issuedAt);
+		}
+
+		return requiredClaims;
 	}
 }

+ 52 - 25
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenValidatorTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2019 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -64,8 +64,9 @@ public class OidcIdTokenValidatorTests {
 		this.claims.remove(IdTokenClaimNames.ISS);
 		assertThat(this.validateIdToken())
 				.hasSize(1)
-				.extracting(OAuth2Error::getErrorCode)
-				.contains("invalid_id_token");
+				.extracting(OAuth2Error::getDescription)
+				.allMatch(msg -> msg.contains(IdTokenClaimNames.ISS));
+
 	}
 
 	@Test
@@ -73,8 +74,8 @@ public class OidcIdTokenValidatorTests {
 		this.claims.remove(IdTokenClaimNames.SUB);
 		assertThat(this.validateIdToken())
 				.hasSize(1)
-				.extracting(OAuth2Error::getErrorCode)
-				.contains("invalid_id_token");
+				.extracting(OAuth2Error::getDescription)
+				.allMatch(msg -> msg.contains(IdTokenClaimNames.SUB));
 	}
 
 	@Test
@@ -82,8 +83,8 @@ public class OidcIdTokenValidatorTests {
 		this.claims.remove(IdTokenClaimNames.AUD);
 		assertThat(this.validateIdToken())
 				.hasSize(1)
-				.extracting(OAuth2Error::getErrorCode)
-				.contains("invalid_id_token");
+				.extracting(OAuth2Error::getDescription)
+				.allMatch(msg -> msg.contains(IdTokenClaimNames.AUD));
 	}
 
 	@Test
@@ -91,8 +92,8 @@ public class OidcIdTokenValidatorTests {
 		this.issuedAt = null;
 		assertThat(this.validateIdToken())
 				.hasSize(1)
-				.extracting(OAuth2Error::getErrorCode)
-				.contains("invalid_id_token");
+				.extracting(OAuth2Error::getDescription)
+				.allMatch(msg -> msg.contains(IdTokenClaimNames.IAT));
 	}
 
 	@Test
@@ -100,8 +101,8 @@ public class OidcIdTokenValidatorTests {
 		this.expiresAt = null;
 		assertThat(this.validateIdToken())
 				.hasSize(1)
-				.extracting(OAuth2Error::getErrorCode)
-				.contains("invalid_id_token");
+				.extracting(OAuth2Error::getDescription)
+				.allMatch(msg -> msg.contains(IdTokenClaimNames.EXP));
 	}
 
 	@Test
@@ -109,8 +110,8 @@ public class OidcIdTokenValidatorTests {
 		this.claims.put(IdTokenClaimNames.AUD, Arrays.asList("client-id", "other"));
 		assertThat(this.validateIdToken())
 				.hasSize(1)
-				.extracting(OAuth2Error::getErrorCode)
-				.contains("invalid_id_token");
+				.extracting(OAuth2Error::getDescription)
+				.allMatch(msg -> msg.contains(IdTokenClaimNames.AZP));
 	}
 
 	@Test
@@ -118,8 +119,8 @@ public class OidcIdTokenValidatorTests {
 		this.claims.put(IdTokenClaimNames.AZP, "other");
 		assertThat(this.validateIdToken())
 				.hasSize(1)
-				.extracting(OAuth2Error::getErrorCode)
-				.contains("invalid_id_token");
+				.extracting(OAuth2Error::getDescription)
+				.allMatch(msg -> msg.contains(IdTokenClaimNames.AZP));
 	}
 
 	@Test
@@ -135,8 +136,8 @@ public class OidcIdTokenValidatorTests {
 		this.claims.put(IdTokenClaimNames.AZP, "other-client");
 		assertThat(this.validateIdToken())
 				.hasSize(1)
-				.extracting(OAuth2Error::getErrorCode)
-				.contains("invalid_id_token");
+				.extracting(OAuth2Error::getDescription)
+				.allMatch(msg -> msg.contains(IdTokenClaimNames.AZP));
 	}
 
 	@Test
@@ -144,8 +145,8 @@ public class OidcIdTokenValidatorTests {
 		this.claims.put(IdTokenClaimNames.AUD, Collections.singletonList("other-client"));
 		assertThat(this.validateIdToken())
 				.hasSize(1)
-				.extracting(OAuth2Error::getErrorCode)
-				.contains("invalid_id_token");
+				.extracting(OAuth2Error::getDescription)
+				.allMatch(msg -> msg.contains(IdTokenClaimNames.AUD));
 	}
 
 	@Test
@@ -154,8 +155,8 @@ public class OidcIdTokenValidatorTests {
 		this.expiresAt = this.issuedAt.plus(Duration.ofSeconds(1));
 		assertThat(this.validateIdToken())
 				.hasSize(1)
-				.extracting(OAuth2Error::getErrorCode)
-				.contains("invalid_id_token");
+				.extracting(OAuth2Error::getDescription)
+				.allMatch(msg -> msg.contains(IdTokenClaimNames.EXP));
 	}
 
 	@Test
@@ -164,8 +165,8 @@ public class OidcIdTokenValidatorTests {
 		this.expiresAt = this.issuedAt.plus(Duration.ofSeconds(1));
 		assertThat(this.validateIdToken())
 				.hasSize(1)
-				.extracting(OAuth2Error::getErrorCode)
-				.contains("invalid_id_token");
+				.extracting(OAuth2Error::getDescription)
+				.allMatch(msg -> msg.contains(IdTokenClaimNames.IAT));
 	}
 
 	@Test
@@ -174,8 +175,34 @@ public class OidcIdTokenValidatorTests {
 		this.expiresAt = Instant.from(this.issuedAt).plusSeconds(5);
 		assertThat(this.validateIdToken())
 				.hasSize(1)
-				.extracting(OAuth2Error::getErrorCode)
-				.contains("invalid_id_token");
+				.extracting(OAuth2Error::getDescription)
+				.allMatch(msg -> msg.contains(IdTokenClaimNames.EXP));
+	}
+
+	@Test
+	public void validateIdTokenWhenMissingClaimsThenHasErrors() {
+		this.claims.remove(IdTokenClaimNames.SUB);
+		this.claims.remove(IdTokenClaimNames.AUD);
+		this.issuedAt = null;
+		this.expiresAt = null;
+		assertThat(this.validateIdToken())
+				.hasSize(1)
+				.extracting(OAuth2Error::getDescription)
+				.allMatch(msg -> msg.contains(IdTokenClaimNames.SUB))
+				.allMatch(msg -> msg.contains(IdTokenClaimNames.AUD))
+				.allMatch(msg -> msg.contains(IdTokenClaimNames.IAT))
+				.allMatch(msg -> msg.contains(IdTokenClaimNames.EXP));
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void validateIdTokenWhenNoClaimsThenHasErrors() {
+		this.claims.remove(IdTokenClaimNames.ISS);
+		this.claims.remove(IdTokenClaimNames.SUB);
+		this.claims.remove(IdTokenClaimNames.AUD);
+		this.issuedAt = null;
+		this.expiresAt = null;
+		assertThat(this.validateIdToken())
+				.hasSize(1);
 	}
 
 	private Collection<OAuth2Error> validateIdToken() {