浏览代码

Introduce OAuth2AuthenticationValidator

Closes gh-374
Joe Grandja 4 年之前
父节点
当前提交
f6c4d49b9f

+ 76 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/authentication/OAuth2AuthenticationContext.java

@@ -0,0 +1,76 @@
+/*
+ * Copyright 2020-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.core.authentication;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import org.springframework.lang.Nullable;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.oauth2.core.context.Context;
+import org.springframework.util.Assert;
+import org.springframework.util.CollectionUtils;
+
+/**
+ * A context that holds an {@link Authentication} and (optionally) additional information
+ * and is used by an {@link OAuth2AuthenticationValidator} when attempting to validate the {@link Authentication}.
+ *
+ * @author Joe Grandja
+ * @since 0.2.0
+ * @see Context
+ * @see OAuth2AuthenticationValidator
+ */
+public final class OAuth2AuthenticationContext implements Context {
+	private final Map<Object, Object> context;
+
+	/**
+	 * Constructs an {@code OAuth2AuthenticationContext} using the provided parameters.
+	 *
+	 * @param authentication the {@code Authentication}
+	 * @param context a {@code Map} of additional context information
+	 */
+	public OAuth2AuthenticationContext(Authentication authentication, @Nullable Map<Object, Object> context) {
+		Assert.notNull(authentication, "authentication cannot be null");
+		this.context = new HashMap<>();
+		if (!CollectionUtils.isEmpty(context)) {
+			this.context.putAll(context);
+		}
+		this.context.put(Authentication.class, authentication);
+	}
+
+	/**
+	 * Returns the {@link Authentication} associated to the authentication context.
+	 *
+	 * @param <T> the type of the {@code Authentication}
+	 * @return the {@link Authentication}
+	 */
+	@SuppressWarnings("unchecked")
+	public <T extends Authentication> T getAuthentication() {
+		return (T) get(Authentication.class);
+	}
+
+	@SuppressWarnings("unchecked")
+	@Override
+	public <V> V get(Object key) {
+		return (V) this.context.get(key);
+	}
+
+	@Override
+	public boolean hasKey(Object key) {
+		return this.context.containsKey(key);
+	}
+
+}

+ 40 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/authentication/OAuth2AuthenticationValidator.java

@@ -0,0 +1,40 @@
+/*
+ * Copyright 2020-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.core.authentication;
+
+import org.springframework.security.core.Authentication;
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+
+/**
+ * Implementations of this interface are responsible for validating the attribute(s)
+ * of the {@link Authentication} associated to the {@link OAuth2AuthenticationContext}.
+ *
+ * @author Joe Grandja
+ * @since 0.2.0
+ * @see OAuth2AuthenticationContext
+ */
+@FunctionalInterface
+public interface OAuth2AuthenticationValidator {
+
+	/**
+	 * Validate the attribute(s) of the {@link Authentication}.
+	 *
+	 * @param authenticationContext the authentication context
+	 * @throws OAuth2AuthenticationException if the attribute(s) of the {@code Authentication} is invalid
+	 */
+	void validate(OAuth2AuthenticationContext authenticationContext) throws OAuth2AuthenticationException;
+
+}

+ 93 - 18
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProvider.java

@@ -20,8 +20,11 @@ import java.time.Instant;
 import java.time.temporal.ChronoUnit;
 import java.util.Base64;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.HashSet;
+import java.util.Map;
 import java.util.Set;
+import java.util.function.Function;
 import java.util.regex.Pattern;
 
 import org.springframework.security.authentication.AnonymousAuthenticationToken;
@@ -34,6 +37,8 @@ import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.OAuth2TokenType;
+import org.springframework.security.oauth2.core.authentication.OAuth2AuthenticationContext;
+import org.springframework.security.oauth2.core.authentication.OAuth2AuthenticationValidator;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
@@ -68,11 +73,14 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
 	private static final String PKCE_ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc7636#section-4.4.1";
 	private static final Pattern LOOPBACK_ADDRESS_PATTERN =
 			Pattern.compile("^127(?:\\.[0-9]+){0,2}\\.[0-9]+$|^\\[(?:0*:)*?:?0*1]$");
+	private static final Function<String, OAuth2AuthenticationValidator> DEFAULT_AUTHENTICATION_VALIDATOR_RESOLVER =
+			createDefaultAuthenticationValidatorResolver();
 	private final RegisteredClientRepository registeredClientRepository;
 	private final OAuth2AuthorizationService authorizationService;
 	private final OAuth2AuthorizationConsentService authorizationConsentService;
 	private final StringKeyGenerator codeGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96);
 	private final StringKeyGenerator stateGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder());
+	private Function<String, OAuth2AuthenticationValidator> authenticationValidatorResolver = DEFAULT_AUTHENTICATION_VALIDATOR_RESOLVER;
 
 	/**
 	 * Constructs an {@code OAuth2AuthorizationCodeRequestAuthenticationProvider} using the provided parameters.
@@ -106,6 +114,26 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
 		return OAuth2AuthorizationCodeRequestAuthenticationToken.class.isAssignableFrom(authentication);
 	}
 
+	/**
+	 * Sets the resolver that resolves an {@link OAuth2AuthenticationValidator} from the provided OAuth 2.0 Authorization Request parameter.
+	 *
+	 * <p>
+	 * The following OAuth 2.0 Authorization Request parameters are supported:
+	 * <ol>
+	 * <li>{@link OAuth2ParameterNames#REDIRECT_URI}</li>
+	 * <li>{@link OAuth2ParameterNames#SCOPE}</li>
+	 * </ol>
+	 *
+	 * <p>
+	 * <b>NOTE:</b> The resolved {@link OAuth2AuthenticationValidator} MUST throw {@link OAuth2AuthorizationCodeRequestAuthenticationException} if validation fails.
+	 *
+	 * @param authenticationValidatorResolver the resolver that resolves an {@link OAuth2AuthenticationValidator} from the provided OAuth 2.0 Authorization Request parameter
+	 */
+	public void setAuthenticationValidatorResolver(Function<String, OAuth2AuthenticationValidator> authenticationValidatorResolver) {
+		Assert.notNull(authenticationValidatorResolver, "authenticationValidatorResolver cannot be null");
+		this.authenticationValidatorResolver = authenticationValidatorResolver;
+	}
+
 	private Authentication authenticateAuthorizationRequest(Authentication authentication) throws AuthenticationException {
 		OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication =
 				(OAuth2AuthorizationCodeRequestAuthenticationToken) authentication;
@@ -117,29 +145,21 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
 					authorizationCodeRequestAuthentication, null);
 		}
 
-		if (StringUtils.hasText(authorizationCodeRequestAuthentication.getRedirectUri())) {
-			if (!isValidRedirectUri(authorizationCodeRequestAuthentication.getRedirectUri(), registeredClient)) {
-				throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI,
-						authorizationCodeRequestAuthentication, registeredClient);
-			}
-		} else if (authorizationCodeRequestAuthentication.getScopes().contains(OidcScopes.OPENID) ||
-				registeredClient.getRedirectUris().size() != 1) {
-			// redirect_uri is REQUIRED for OpenID Connect
-			throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI,
-					authorizationCodeRequestAuthentication, registeredClient);
-		}
+		Map<Object, Object> context = new HashMap<>();
+		context.put(RegisteredClient.class, registeredClient);
+		OAuth2AuthenticationContext authenticationContext = new OAuth2AuthenticationContext(
+				authorizationCodeRequestAuthentication, context);
+
+		OAuth2AuthenticationValidator redirectUriValidator = resolveAuthenticationValidator(OAuth2ParameterNames.REDIRECT_URI);
+		redirectUriValidator.validate(authenticationContext);
 
 		if (!registeredClient.getAuthorizationGrantTypes().contains(AuthorizationGrantType.AUTHORIZATION_CODE)) {
 			throwError(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT, OAuth2ParameterNames.CLIENT_ID,
 					authorizationCodeRequestAuthentication, registeredClient);
 		}
 
-		Set<String> requestedScopes = authorizationCodeRequestAuthentication.getScopes();
-		Set<String> allowedScopes = registeredClient.getScopes();
-		if (!requestedScopes.isEmpty() && !allowedScopes.containsAll(requestedScopes)) {
-			throwError(OAuth2ErrorCodes.INVALID_SCOPE, OAuth2ParameterNames.SCOPE,
-					authorizationCodeRequestAuthentication, registeredClient);
-		}
+		OAuth2AuthenticationValidator scopeValidator = resolveAuthenticationValidator(OAuth2ParameterNames.SCOPE);
+		scopeValidator.validate(authenticationContext);
 
 		// code_challenge (REQUIRED for public clients) - RFC 7636 (PKCE)
 		String codeChallenge = (String) authorizationCodeRequestAuthentication.getAdditionalParameters().get(PkceParameterNames.CODE_CHALLENGE);
@@ -170,7 +190,7 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
 				.authorizationUri(authorizationCodeRequestAuthentication.getAuthorizationUri())
 				.clientId(registeredClient.getClientId())
 				.redirectUri(authorizationCodeRequestAuthentication.getRedirectUri())
-				.scopes(requestedScopes)
+				.scopes(authorizationCodeRequestAuthentication.getScopes())
 				.state(authorizationCodeRequestAuthentication.getState())
 				.additionalParameters(authorizationCodeRequestAuthentication.getAdditionalParameters())
 				.build();
@@ -227,6 +247,13 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
 				.build();
 	}
 
+	private OAuth2AuthenticationValidator resolveAuthenticationValidator(String parameterName) {
+		OAuth2AuthenticationValidator authenticationValidator = this.authenticationValidatorResolver.apply(parameterName);
+		return authenticationValidator != null ?
+				authenticationValidator :
+				DEFAULT_AUTHENTICATION_VALIDATOR_RESOLVER.apply(parameterName);
+	}
+
 	private OAuth2AuthorizationCode createAuthorizationCode() {
 		Instant issuedAt = Instant.now();
 		Instant expiresAt = issuedAt.plus(5, ChronoUnit.MINUTES);		// TODO Allow configuration for authorization code time-to-live
@@ -329,6 +356,13 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
 				.build();
 	}
 
+	private static Function<String, OAuth2AuthenticationValidator> createDefaultAuthenticationValidatorResolver() {
+		Map<String, OAuth2AuthenticationValidator> authenticationValidators = new HashMap<>();
+		authenticationValidators.put(OAuth2ParameterNames.REDIRECT_URI, new DefaultRedirectUriOAuth2AuthenticationValidator());
+		authenticationValidators.put(OAuth2ParameterNames.SCOPE, new DefaultScopeOAuth2AuthenticationValidator());
+		return authenticationValidators::get;
+	}
+
 	private static OAuth2Authorization.Builder authorizationBuilder(RegisteredClient registeredClient, Authentication principal,
 			OAuth2AuthorizationRequest authorizationRequest) {
 		return OAuth2Authorization.withRegisteredClient(registeredClient)
@@ -474,4 +508,45 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
 				.authorizationCode(authorizationCodeRequestAuthentication.getAuthorizationCode());
 	}
 
+	private static class DefaultRedirectUriOAuth2AuthenticationValidator implements OAuth2AuthenticationValidator {
+
+		@Override
+		public void validate(OAuth2AuthenticationContext authenticationContext) {
+			OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication =
+					authenticationContext.getAuthentication();
+			RegisteredClient registeredClient = authenticationContext.get(RegisteredClient.class);
+
+			if (StringUtils.hasText(authorizationCodeRequestAuthentication.getRedirectUri())) {
+				if (!isValidRedirectUri(authorizationCodeRequestAuthentication.getRedirectUri(), registeredClient)) {
+					throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI,
+							authorizationCodeRequestAuthentication, registeredClient);
+				}
+			} else if (authorizationCodeRequestAuthentication.getScopes().contains(OidcScopes.OPENID) ||
+					registeredClient.getRedirectUris().size() != 1) {
+				// redirect_uri is REQUIRED for OpenID Connect
+				throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI,
+						authorizationCodeRequestAuthentication, registeredClient);
+			}
+		}
+
+	}
+
+	private static class DefaultScopeOAuth2AuthenticationValidator implements OAuth2AuthenticationValidator {
+
+		@Override
+		public void validate(OAuth2AuthenticationContext authenticationContext) {
+			OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication =
+					authenticationContext.getAuthentication();
+			RegisteredClient registeredClient = authenticationContext.get(RegisteredClient.class);
+
+			Set<String> requestedScopes = authorizationCodeRequestAuthentication.getScopes();
+			Set<String> allowedScopes = registeredClient.getScopes();
+			if (!requestedScopes.isEmpty() && !allowedScopes.containsAll(requestedScopes)) {
+				throwError(OAuth2ErrorCodes.INVALID_SCOPE, OAuth2ParameterNames.SCOPE,
+						authorizationCodeRequestAuthentication, registeredClient);
+			}
+		}
+
+	}
+
 }

+ 35 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProviderTests.java

@@ -21,6 +21,7 @@ import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Map;
 import java.util.Set;
+import java.util.function.Function;
 
 import org.junit.Before;
 import org.junit.Test;
@@ -32,6 +33,7 @@ import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.OAuth2TokenType;
+import org.springframework.security.oauth2.core.authentication.OAuth2AuthenticationValidator;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
@@ -54,6 +56,7 @@ import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
@@ -110,6 +113,13 @@ public class OAuth2AuthorizationCodeRequestAuthenticationProviderTests {
 		assertThat(this.authenticationProvider.supports(OAuth2AuthorizationCodeRequestAuthenticationToken.class)).isTrue();
 	}
 
+	@Test
+	public void setAuthenticationValidatorResolverWhenNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.authenticationProvider.setAuthenticationValidatorResolver(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("authenticationValidatorResolver cannot be null");
+	}
+
 	@Test
 	public void authenticateWhenInvalidClientIdThenThrowOAuth2AuthorizationCodeRequestAuthenticationException() {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
@@ -477,6 +487,31 @@ public class OAuth2AuthorizationCodeRequestAuthenticationProviderTests {
 		assertAuthorizationCodeRequestWithAuthorizationCodeResult(registeredClient, authentication, authenticationResult);
 	}
 
+	@Test
+	public void authenticateWhenCustomAuthenticationValidatorResolverThenUsed() {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+		when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
+				.thenReturn(registeredClient);
+
+		@SuppressWarnings("unchecked")
+		Function<String, OAuth2AuthenticationValidator> authenticationValidatorResolver = mock(Function.class);
+		this.authenticationProvider.setAuthenticationValidatorResolver(authenticationValidatorResolver);
+
+		OAuth2AuthorizationCodeRequestAuthenticationToken authentication =
+				authorizationCodeRequestAuthentication(registeredClient, this.principal)
+						.build();
+
+		OAuth2AuthorizationCodeRequestAuthenticationToken authenticationResult =
+				(OAuth2AuthorizationCodeRequestAuthenticationToken) this.authenticationProvider.authenticate(authentication);
+
+		assertAuthorizationCodeRequestWithAuthorizationCodeResult(registeredClient, authentication, authenticationResult);
+
+		ArgumentCaptor<String> parameterNameCaptor = ArgumentCaptor.forClass(String.class);
+		verify(authenticationValidatorResolver, times(2)).apply(parameterNameCaptor.capture());
+		assertThat(parameterNameCaptor.getAllValues()).containsExactly(
+				OAuth2ParameterNames.REDIRECT_URI, OAuth2ParameterNames.SCOPE);
+	}
+
 	private void assertAuthorizationCodeRequestWithAuthorizationCodeResult(
 			RegisteredClient registeredClient,
 			OAuth2AuthorizationCodeRequestAuthenticationToken authentication,