Browse Source

Add JwtValidators append to default

Implemented simplified creation of default OAuth2TokenValidator with additional validators.

Closes gh-14831
Max Batischev 1 year ago
parent
commit
ff19f04fca

+ 3 - 5
config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/DefaultOidcLogoutTokenValidatorFactory.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2023 the original author or authors.
+ * Copyright 2002-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -19,17 +19,15 @@ package org.springframework.security.config.annotation.web.configurers.oauth2.cl
 import java.util.function.Function;
 
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
-import org.springframework.security.oauth2.core.DelegatingOAuth2TokenValidator;
 import org.springframework.security.oauth2.core.OAuth2TokenValidator;
 import org.springframework.security.oauth2.jwt.Jwt;
-import org.springframework.security.oauth2.jwt.JwtTimestampValidator;
+import org.springframework.security.oauth2.jwt.JwtValidators;
 
 final class DefaultOidcLogoutTokenValidatorFactory implements Function<ClientRegistration, OAuth2TokenValidator<Jwt>> {
 
 	@Override
 	public OAuth2TokenValidator<Jwt> apply(ClientRegistration clientRegistration) {
-		return new DelegatingOAuth2TokenValidator<>(new JwtTimestampValidator(),
-				new OidcBackChannelLogoutTokenValidator(clientRegistration));
+		return JwtValidators.createDefaultWithValidators(new OidcBackChannelLogoutTokenValidator(clientRegistration));
 	}
 
 }

+ 3 - 5
config/src/main/java/org/springframework/security/config/web/server/DefaultOidcLogoutTokenValidatorFactory.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2023 the original author or authors.
+ * Copyright 2002-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -19,17 +19,15 @@ package org.springframework.security.config.web.server;
 import java.util.function.Function;
 
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
-import org.springframework.security.oauth2.core.DelegatingOAuth2TokenValidator;
 import org.springframework.security.oauth2.core.OAuth2TokenValidator;
 import org.springframework.security.oauth2.jwt.Jwt;
-import org.springframework.security.oauth2.jwt.JwtTimestampValidator;
+import org.springframework.security.oauth2.jwt.JwtValidators;
 
 final class DefaultOidcLogoutTokenValidatorFactory implements Function<ClientRegistration, OAuth2TokenValidator<Jwt>> {
 
 	@Override
 	public OAuth2TokenValidator<Jwt> apply(ClientRegistration clientRegistration) {
-		return new DelegatingOAuth2TokenValidator<>(new JwtTimestampValidator(),
-				new OidcBackChannelLogoutTokenValidator(clientRegistration));
+		return JwtValidators.createDefaultWithValidators(new OidcBackChannelLogoutTokenValidator(clientRegistration));
 	}
 
 }

+ 3 - 5
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/DefaultOidcIdTokenValidatorFactory.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -19,10 +19,9 @@ package org.springframework.security.oauth2.client.oidc.authentication;
 import java.util.function.Function;
 
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
-import org.springframework.security.oauth2.core.DelegatingOAuth2TokenValidator;
 import org.springframework.security.oauth2.core.OAuth2TokenValidator;
 import org.springframework.security.oauth2.jwt.Jwt;
-import org.springframework.security.oauth2.jwt.JwtTimestampValidator;
+import org.springframework.security.oauth2.jwt.JwtValidators;
 
 /**
  * @author Joe Grandja
@@ -32,8 +31,7 @@ class DefaultOidcIdTokenValidatorFactory implements Function<ClientRegistration,
 
 	@Override
 	public OAuth2TokenValidator<Jwt> apply(ClientRegistration clientRegistration) {
-		return new DelegatingOAuth2TokenValidator<>(new JwtTimestampValidator(),
-				new OidcIdTokenValidator(clientRegistration));
+		return JwtValidators.createDefaultWithValidators(new OidcIdTokenValidator(clientRegistration));
 	}
 
 }

+ 40 - 1
oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/JwtValidators.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -22,6 +22,8 @@ import java.util.List;
 
 import org.springframework.security.oauth2.core.DelegatingOAuth2TokenValidator;
 import org.springframework.security.oauth2.core.OAuth2TokenValidator;
+import org.springframework.util.Assert;
+import org.springframework.util.CollectionUtils;
 
 /**
  * Provides factory methods for creating {@code OAuth2TokenValidator<Jwt>}
@@ -72,4 +74,41 @@ public final class JwtValidators {
 		return new DelegatingOAuth2TokenValidator<>(Arrays.asList(new JwtTimestampValidator()));
 	}
 
+	/**
+	 * <p>
+	 * Create a {@link Jwt} default validator with standard validators and additional
+	 * validators.
+	 * </p>
+	 * @param validators additional validators
+	 * @return - a delegating validator containing all standard validators with additional
+	 * validators
+	 * @since 6.3
+	 */
+	public static OAuth2TokenValidator<Jwt> createDefaultWithValidators(List<OAuth2TokenValidator<Jwt>> validators) {
+		Assert.notEmpty(validators, "validators cannot be null or empty");
+		List<OAuth2TokenValidator<Jwt>> tokenValidators = new ArrayList<>(validators);
+		JwtTimestampValidator jwtTimestampValidator = CollectionUtils.findValueOfType(tokenValidators,
+				JwtTimestampValidator.class);
+		if (jwtTimestampValidator == null) {
+			tokenValidators.add(new JwtTimestampValidator());
+		}
+		return new DelegatingOAuth2TokenValidator<>(tokenValidators);
+	}
+
+	/**
+	 * <p>
+	 * Create a {@link Jwt} default validator with standard validators and additional
+	 * validators.
+	 * </p>
+	 * @param validators additional validators
+	 * @return - a delegating validator containing all standard validators with additional
+	 * validators
+	 * @since 6.3
+	 */
+	public static OAuth2TokenValidator<Jwt> createDefaultWithValidators(OAuth2TokenValidator<Jwt>... validators) {
+		Assert.notEmpty(validators, "validators cannot be null or empty");
+		List<OAuth2TokenValidator<Jwt>> tokenValidators = new ArrayList<>(Arrays.asList(validators));
+		return createDefaultWithValidators(tokenValidators);
+	}
+
 }

+ 80 - 0
oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtValidatorsTests.java

@@ -0,0 +1,80 @@
+/*
+ * Copyright 2002-2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.jwt;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Objects;
+
+import org.junit.jupiter.api.Test;
+
+import org.springframework.security.oauth2.core.DelegatingOAuth2TokenValidator;
+import org.springframework.security.oauth2.core.OAuth2TokenValidator;
+import org.springframework.test.util.ReflectionTestUtils;
+import org.springframework.util.CollectionUtils;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatException;
+
+/**
+ * Tests for {@link JwtValidators}.
+ *
+ * @author Max Batischev
+ */
+public class JwtValidatorsTests {
+
+	private static final String ISS_CLAIM = "iss";
+
+	@Test
+	public void createWhenJwtIssuerValidatorIsPresentThenCreateDefaultValidatorWithJwtIssuerValidator() {
+		OAuth2TokenValidator<Jwt> validator = JwtValidators
+			.createDefaultWithValidators(new JwtIssuerValidator(ISS_CLAIM));
+
+		assertThat(containsByType(validator, JwtIssuerValidator.class)).isTrue();
+		assertThat(containsByType(validator, JwtTimestampValidator.class)).isTrue();
+	}
+
+	@Test
+	@SuppressWarnings("unchecked")
+	public void createWhenJwtTimestampValidatorIsPresentThenCreateDefaultValidatorWithOnlyOneJwtTimestampValidator() {
+		OAuth2TokenValidator<Jwt> validator = JwtValidators.createDefaultWithValidators(new JwtTimestampValidator());
+
+		DelegatingOAuth2TokenValidator<Jwt> delegatingOAuth2TokenValidator = (DelegatingOAuth2TokenValidator<Jwt>) validator;
+		Collection<OAuth2TokenValidator<Jwt>> tokenValidators = (Collection<OAuth2TokenValidator<Jwt>>) ReflectionTestUtils
+			.getField(delegatingOAuth2TokenValidator, "tokenValidators");
+
+		assertThat(containsByType(validator, JwtTimestampValidator.class)).isTrue();
+		assertThat(Objects.requireNonNull(tokenValidators).size()).isEqualTo(1);
+	}
+
+	@Test
+	public void createWhenEmptyValidatorsThenThrowsException() {
+		assertThatException().isThrownBy(() -> JwtValidators.createDefaultWithValidators(Collections.emptyList()));
+	}
+
+	@SuppressWarnings("unchecked")
+	private boolean containsByType(OAuth2TokenValidator<Jwt> validator, Class<? extends OAuth2TokenValidator<?>> type) {
+		DelegatingOAuth2TokenValidator<Jwt> delegatingOAuth2TokenValidator = (DelegatingOAuth2TokenValidator<Jwt>) validator;
+		Collection<OAuth2TokenValidator<Jwt>> tokenValidators = (Collection<OAuth2TokenValidator<Jwt>>) ReflectionTestUtils
+			.getField(delegatingOAuth2TokenValidator, "tokenValidators");
+
+		OAuth2TokenValidator<?> tokenValidator = CollectionUtils
+			.findValueOfType(Objects.requireNonNull(tokenValidators), type);
+		return tokenValidator != null;
+	}
+
+}