Browse Source

Fix NPE with exp claim in NimbusJwtDecoderJwkSupport

Fixes gh-5168
Joe Grandja 7 years ago
parent
commit
d8f91e4261

+ 25 - 15
oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/AbstractOAuth2Token.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2017 the original author or authors.
+ * Copyright 2002-2018 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,6 +15,7 @@
  */
 package org.springframework.security.oauth2.core;
 
+import org.springframework.lang.Nullable;
 import org.springframework.security.core.SpringSecurityCoreVersion;
 import org.springframework.util.Assert;
 
@@ -38,14 +39,23 @@ public abstract class AbstractOAuth2Token implements Serializable {
 	 * Sub-class constructor.
 	 *
 	 * @param tokenValue the token value
-	 * @param issuedAt the time at which the token was issued
-	 * @param expiresAt the expiration time on or after which the token MUST NOT be accepted
 	 */
-	protected AbstractOAuth2Token(String tokenValue, Instant issuedAt, Instant expiresAt) {
+	protected AbstractOAuth2Token(String tokenValue) {
+		this(tokenValue, null, null);
+	}
+
+	/**
+	 * Sub-class constructor.
+	 *
+	 * @param tokenValue the token value
+	 * @param issuedAt the time at which the token was issued, may be null
+	 * @param expiresAt the expiration time on or after which the token MUST NOT be accepted, may be null
+	 */
+	protected AbstractOAuth2Token(String tokenValue, @Nullable Instant issuedAt, @Nullable Instant expiresAt) {
 		Assert.hasText(tokenValue, "tokenValue cannot be empty");
-		Assert.notNull(issuedAt, "issuedAt cannot be null");
-		Assert.notNull(expiresAt, "expiresAt cannot be null");
-		Assert.isTrue(expiresAt.isAfter(issuedAt), "expiresAt must be after issuedAt");
+		if (issuedAt != null && expiresAt != null) {
+			Assert.isTrue(expiresAt.isAfter(issuedAt), "expiresAt must be after issuedAt");
+		}
 		this.tokenValue = tokenValue;
 		this.issuedAt = issuedAt;
 		this.expiresAt = expiresAt;
@@ -63,18 +73,18 @@ public abstract class AbstractOAuth2Token implements Serializable {
 	/**
 	 * Returns the time at which the token was issued.
 	 *
-	 * @return the time the token was issued
+	 * @return the time the token was issued or null
 	 */
-	public Instant getIssuedAt() {
+	public @Nullable Instant getIssuedAt() {
 		return this.issuedAt;
 	}
 
 	/**
 	 * Returns the expiration time on or after which the token MUST NOT be accepted.
 	 *
-	 * @return the expiration time of the token
+	 * @return the expiration time of the token or null
 	 */
-	public Instant getExpiresAt() {
+	public @Nullable Instant getExpiresAt() {
 		return this.expiresAt;
 	}
 
@@ -92,17 +102,17 @@ public abstract class AbstractOAuth2Token implements Serializable {
 		if (!this.getTokenValue().equals(that.getTokenValue())) {
 			return false;
 		}
-		if (!this.getIssuedAt().equals(that.getIssuedAt())) {
+		if (this.getIssuedAt() != null ? !this.getIssuedAt().equals(that.getIssuedAt()) : that.getIssuedAt() != null) {
 			return false;
 		}
-		return this.getExpiresAt().equals(that.getExpiresAt());
+		return this.getExpiresAt() != null ? this.getExpiresAt().equals(that.getExpiresAt()) : that.getExpiresAt() == null;
 	}
 
 	@Override
 	public int hashCode() {
 		int result = this.getTokenValue().hashCode();
-		result = 31 * result + this.getIssuedAt().hashCode();
-		result = 31 * result + this.getExpiresAt().hashCode();
+		result = 31 * result + (this.getIssuedAt() != null ? this.getIssuedAt().hashCode() : 0);
+		result = 31 * result + (this.getExpiresAt() != null ? this.getExpiresAt().hashCode() : 0);
 		return result;
 	}
 }

+ 0 - 10
oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/OAuth2AccessTokenTests.java

@@ -51,16 +51,6 @@ public class OAuth2AccessTokenTests {
 		new OAuth2AccessToken(TOKEN_TYPE, null, ISSUED_AT, EXPIRES_AT);
 	}
 
-	@Test(expected = IllegalArgumentException.class)
-	public void constructorWhenIssuedAtIsNullThenThrowIllegalArgumentException() {
-		new OAuth2AccessToken(TOKEN_TYPE, TOKEN_VALUE, null, EXPIRES_AT);
-	}
-
-	@Test(expected = IllegalArgumentException.class)
-	public void constructorWhenExpiresAtIsNullThenThrowIllegalArgumentException() {
-		new OAuth2AccessToken(TOKEN_TYPE, TOKEN_VALUE, ISSUED_AT, null);
-	}
-
 	@Test(expected = IllegalArgumentException.class)
 	public void constructorWhenIssuedAtAfterExpiresAtThenThrowIllegalArgumentException() {
 		new OAuth2AccessToken(TOKEN_TYPE, TOKEN_VALUE, Instant.from(EXPIRES_AT).plusSeconds(1), EXPIRES_AT);

+ 0 - 10
oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/OidcIdTokenTests.java

@@ -82,16 +82,6 @@ public class OidcIdTokenTests {
 		new OidcIdToken(null, Instant.ofEpochMilli(IAT_VALUE), Instant.ofEpochMilli(EXP_VALUE), CLAIMS);
 	}
 
-	@Test(expected = IllegalArgumentException.class)
-	public void constructorWhenIssuedAtIsNullThenThrowIllegalArgumentException() {
-		new OidcIdToken(ID_TOKEN_VALUE, null, Instant.ofEpochMilli(EXP_VALUE), CLAIMS);
-	}
-
-	@Test(expected = IllegalArgumentException.class)
-	public void constructorWhenExpiresAtIsNullThenThrowIllegalArgumentException() {
-		new OidcIdToken(ID_TOKEN_VALUE, Instant.ofEpochMilli(IAT_VALUE), null, CLAIMS);
-	}
-
 	@Test(expected = IllegalArgumentException.class)
 	public void constructorWhenClaimsIsEmptyThenThrowIllegalArgumentException() {
 		new OidcIdToken(ID_TOKEN_VALUE, Instant.ofEpochMilli(IAT_VALUE),

+ 2 - 0
oauth2/oauth2-jose/spring-security-oauth2-jose.gradle

@@ -5,4 +5,6 @@ dependencies {
 	compile project(':spring-security-oauth2-core')
 	compile springCoreDependency
 	compile 'com.nimbusds:nimbus-jose-jwt'
+
+	testCompile powerMock2Dependencies
 }

+ 7 - 4
oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderJwkSupport.java

@@ -103,12 +103,15 @@ public final class NimbusJwtDecoderJwkSupport implements JwtDecoder {
 			// Verify the signature
 			JWTClaimsSet jwtClaimsSet = this.jwtProcessor.process(parsedJwt, null);
 
-			Instant expiresAt = jwtClaimsSet.getExpirationTime().toInstant();
-			Instant issuedAt;
+			Instant expiresAt = null;
+			if (jwtClaimsSet.getExpirationTime() != null) {
+				expiresAt = jwtClaimsSet.getExpirationTime().toInstant();
+			}
+			Instant issuedAt = null;
 			if (jwtClaimsSet.getIssueTime() != null) {
 				issuedAt = jwtClaimsSet.getIssueTime().toInstant();
-			} else {
-				// issuedAt is required in AbstractOAuth2Token so let's default to expiresAt - 1 second
+			} else if (expiresAt != null) {
+				// Default to expiresAt - 1 second
 				issuedAt = Instant.from(expiresAt).minusSeconds(1);
 			}
 

+ 0 - 10
oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtTests.java

@@ -72,16 +72,6 @@ public class JwtTests {
 		new Jwt(null, Instant.ofEpochMilli(IAT_VALUE), Instant.ofEpochMilli(EXP_VALUE), HEADERS, CLAIMS);
 	}
 
-	@Test(expected = IllegalArgumentException.class)
-	public void constructorWhenIssuedAtIsNullThenThrowIllegalArgumentException() {
-		new Jwt(JWT_TOKEN_VALUE, null, Instant.ofEpochMilli(EXP_VALUE), HEADERS, CLAIMS);
-	}
-
-	@Test(expected = IllegalArgumentException.class)
-	public void constructorWhenExpiresAtIsNullThenThrowIllegalArgumentException() {
-		new Jwt(JWT_TOKEN_VALUE, Instant.ofEpochMilli(IAT_VALUE), null, HEADERS, CLAIMS);
-	}
-
 	@Test(expected = IllegalArgumentException.class)
 	public void constructorWhenHeadersIsEmptyThenThrowIllegalArgumentException() {
 		new Jwt(JWT_TOKEN_VALUE, Instant.ofEpochMilli(IAT_VALUE),

+ 49 - 8
oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderJwkSupportTests.java

@@ -15,36 +15,77 @@
  */
 package org.springframework.security.oauth2.jwt;
 
+import com.nimbusds.jose.JWSAlgorithm;
+import com.nimbusds.jose.JWSHeader;
+import com.nimbusds.jwt.JWT;
+import com.nimbusds.jwt.JWTClaimsSet;
+import com.nimbusds.jwt.JWTParser;
+import com.nimbusds.jwt.proc.DefaultJWTProcessor;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.powermock.core.classloader.annotations.PrepareForTest;
+import org.powermock.modules.junit4.PowerMockRunner;
 import org.springframework.security.oauth2.jose.jws.JwsAlgorithms;
 
+import static org.assertj.core.api.AssertionsForClassTypes.assertThatCode;
+import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
+import static org.mockito.ArgumentMatchers.*;
+import static org.mockito.Mockito.mock;
+import static org.powermock.api.mockito.PowerMockito.*;
+
 /**
  * Tests for {@link NimbusJwtDecoderJwkSupport}.
  *
  * @author Joe Grandja
  */
+@RunWith(PowerMockRunner.class)
+@PrepareForTest({NimbusJwtDecoderJwkSupport.class, JWTParser.class})
 public class NimbusJwtDecoderJwkSupportTests {
 	private static final String JWK_SET_URL = "https://provider.com/oauth2/keys";
 	private static final String JWS_ALGORITHM = JwsAlgorithms.RS256;
 
-	@Test(expected = IllegalArgumentException.class)
+	@Test
 	public void constructorWhenJwkSetUrlIsNullThenThrowIllegalArgumentException() {
-		new NimbusJwtDecoderJwkSupport(null);
+		assertThatThrownBy(() -> new NimbusJwtDecoderJwkSupport(null))
+				.isInstanceOf(IllegalArgumentException.class);
 	}
 
-	@Test(expected = IllegalArgumentException.class)
+	@Test
 	public void constructorWhenJwkSetUrlInvalidThenThrowIllegalArgumentException() {
-		new NimbusJwtDecoderJwkSupport("invalid.com");
+		assertThatThrownBy(() -> new NimbusJwtDecoderJwkSupport("invalid.com"))
+				.isInstanceOf(IllegalArgumentException.class);
 	}
 
-	@Test(expected = IllegalArgumentException.class)
+	@Test
 	public void constructorWhenJwsAlgorithmIsNullThenThrowIllegalArgumentException() {
-		new NimbusJwtDecoderJwkSupport(JWK_SET_URL, null);
+		assertThatThrownBy(() -> new NimbusJwtDecoderJwkSupport(JWK_SET_URL, null))
+				.isInstanceOf(IllegalArgumentException.class);
 	}
 
-	@Test(expected = JwtException.class)
+	@Test
 	public void decodeWhenJwtInvalidThenThrowJwtException() {
 		NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(JWK_SET_URL, JWS_ALGORITHM);
-		jwtDecoder.decode("invalid");
+		assertThatThrownBy(() -> jwtDecoder.decode("invalid"))
+				.isInstanceOf(JwtException.class);
+	}
+
+	// gh-5168
+	@Test
+	public void decodeWhenExpClaimNullThenDoesNotThrowException() throws Exception {
+		JWT jwt = mock(JWT.class);
+		JWSHeader header = new JWSHeader.Builder(JWSAlgorithm.parse(JWS_ALGORITHM)).build();
+		when(jwt.getHeader()).thenReturn(header);
+
+		mockStatic(JWTParser.class);
+		when(JWTParser.parse(anyString())).thenReturn(jwt);
+
+		DefaultJWTProcessor jwtProcessor = mock(DefaultJWTProcessor.class);
+		whenNew(DefaultJWTProcessor.class).withAnyArguments().thenReturn(jwtProcessor);
+
+		JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().audience("resource1").build();
+		when(jwtProcessor.process(any(JWT.class), eq(null))).thenReturn(jwtClaimsSet);
+
+		NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(JWK_SET_URL, JWS_ALGORITHM);
+		assertThatCode(() -> jwtDecoder.decode("encoded-jwt")).doesNotThrowAnyException();
 	}
 }