浏览代码

Allow subclassing OAuth2AuthenticationContext

Closes gh-492
Joe Grandja 3 年之前
父节点
当前提交
5fa1e8e3b1

+ 109 - 9
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/authentication/OAuth2AuthenticationContext.java

@@ -15,8 +15,10 @@
  */
 package org.springframework.security.oauth2.core.authentication;
 
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.function.Consumer;
 
 import org.springframework.lang.Nullable;
 import org.springframework.security.core.Authentication;
@@ -25,15 +27,13 @@ import org.springframework.util.Assert;
 import org.springframework.util.CollectionUtils;
 
 /**
- * A context that holds an {@link Authentication} and (optionally) additional information
- * and is used by an {@link OAuth2AuthenticationValidator} when attempting to validate the {@link Authentication}.
+ * A context that holds an {@link Authentication} and (optionally) additional information.
  *
  * @author Joe Grandja
  * @since 0.2.0
  * @see Context
- * @see OAuth2AuthenticationValidator
  */
-public final class OAuth2AuthenticationContext implements Context {
+public class OAuth2AuthenticationContext implements Context {
 	private final Map<Object, Object> context;
 
 	/**
@@ -41,18 +41,27 @@ public final class OAuth2AuthenticationContext implements Context {
 	 *
 	 * @param authentication the {@code Authentication}
 	 * @param context a {@code Map} of additional context information
+	 * @deprecated Use {@link #with(Authentication)} instead
 	 */
+	@Deprecated
 	public OAuth2AuthenticationContext(Authentication authentication, @Nullable Map<Object, Object> context) {
 		Assert.notNull(authentication, "authentication cannot be null");
-		this.context = new HashMap<>();
+		Map<Object, Object> ctx = new HashMap<>();
 		if (!CollectionUtils.isEmpty(context)) {
-			this.context.putAll(context);
+			ctx.putAll(context);
 		}
-		this.context.put(Authentication.class, authentication);
+		ctx.put(Authentication.class, authentication);
+		this.context = Collections.unmodifiableMap(ctx);
+	}
+
+	protected OAuth2AuthenticationContext(Map<Object, Object> context) {
+		Assert.notEmpty(context, "context cannot be empty");
+		Assert.notNull(context.get(Authentication.class), "authentication cannot be null");
+		this.context = Collections.unmodifiableMap(new HashMap<>(context));
 	}
 
 	/**
-	 * Returns the {@link Authentication} associated to the authentication context.
+	 * Returns the {@link Authentication} associated to the context.
 	 *
 	 * @param <T> the type of the {@code Authentication}
 	 * @return the {@link Authentication}
@@ -63,14 +72,105 @@ public final class OAuth2AuthenticationContext implements Context {
 	}
 
 	@SuppressWarnings("unchecked")
+	@Nullable
 	@Override
 	public <V> V get(Object key) {
-		return (V) this.context.get(key);
+		return hasKey(key) ? (V) this.context.get(key) : null;
 	}
 
 	@Override
 	public boolean hasKey(Object key) {
+		Assert.notNull(key, "key cannot be null");
 		return this.context.containsKey(key);
 	}
 
+	/**
+	 * Constructs a new {@link Builder} with the provided {@link Authentication}.
+	 *
+	 * @param authentication the {@link Authentication}
+	 * @return the {@link Builder}
+	 */
+	public static Builder with(Authentication authentication) {
+		return new Builder(authentication);
+	}
+
+	/**
+	 * A builder for {@link OAuth2AuthenticationContext}.
+	 */
+	public static final class Builder extends AbstractBuilder<OAuth2AuthenticationContext, Builder> {
+
+		private Builder(Authentication authentication) {
+			super(authentication);
+		}
+
+		@Override
+		public OAuth2AuthenticationContext build() {
+			return new OAuth2AuthenticationContext(getContext());
+		}
+
+	}
+
+	/**
+	 * A builder for subclasses of {@link OAuth2AuthenticationContext}.
+	 *
+	 * @param <T> the type of the authentication context
+	 * @param <B> the type of the builder
+	 */
+	protected static abstract class AbstractBuilder<T extends OAuth2AuthenticationContext, B extends AbstractBuilder<T, B>> {
+		private final Map<Object, Object> context = new HashMap<>();
+
+		protected AbstractBuilder(Authentication authentication) {
+			Assert.notNull(authentication, "authentication cannot be null");
+			put(Authentication.class, authentication);
+		}
+
+		/**
+		 * Associates an attribute.
+		 *
+		 * @param key the key for the attribute
+		 * @param value the value of the attribute
+		 * @return the {@link AbstractBuilder} for further configuration
+		 */
+		public B put(Object key, Object value) {
+			Assert.notNull(key, "key cannot be null");
+			Assert.notNull(value, "value cannot be null");
+			getContext().put(key, value);
+			return getThis();
+		}
+
+		/**
+		 * A {@code Consumer} of the attributes {@code Map}
+		 * allowing the ability to add, replace, or remove.
+		 *
+		 * @param contextConsumer a {@link Consumer} of the attributes {@code Map}
+		 * @return the {@link AbstractBuilder} for further configuration
+		 */
+		public B context(Consumer<Map<Object, Object>> contextConsumer) {
+			contextConsumer.accept(getContext());
+			return getThis();
+		}
+
+		@SuppressWarnings("unchecked")
+		protected <V> V get(Object key) {
+			return (V) getContext().get(key);
+		}
+
+		protected Map<Object, Object> getContext() {
+			return this.context;
+		}
+
+		@SuppressWarnings("unchecked")
+		protected final B getThis() {
+			return (B) this;
+		}
+
+		/**
+		 * Builds a new {@link OAuth2AuthenticationContext}.
+		 *
+		 * @return the {@link OAuth2AuthenticationContext}
+		 */
+		public abstract T build();
+
+	}
+
 }

+ 4 - 4
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProvider.java

@@ -156,10 +156,10 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
 					authorizationCodeRequestAuthentication, null);
 		}
 
-		Map<Object, Object> context = new HashMap<>();
-		context.put(RegisteredClient.class, registeredClient);
-		OAuth2AuthenticationContext authenticationContext = new OAuth2AuthenticationContext(
-				authorizationCodeRequestAuthentication, context);
+		OAuth2AuthenticationContext authenticationContext =
+				OAuth2AuthenticationContext.with(authorizationCodeRequestAuthentication)
+						.put(RegisteredClient.class, registeredClient)
+						.build();
 
 		OAuth2AuthenticationValidator redirectUriValidator = resolveAuthenticationValidator(OAuth2ParameterNames.REDIRECT_URI);
 		redirectUriValidator.validate(authenticationContext);

+ 5 - 5
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcUserInfoAuthenticationProvider.java

@@ -98,11 +98,11 @@ public final class OidcUserInfoAuthenticationProvider implements AuthenticationP
 			throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN);
 		}
 
-		Map<Object, Object> context = new HashMap<>();
-		context.put(OAuth2Token.class, accessTokenAuthentication.getToken());
-		context.put(OAuth2Authorization.class, authorization);
-		OAuth2AuthenticationContext authenticationContext = new OAuth2AuthenticationContext(
-				userInfoAuthentication, context);
+		OAuth2AuthenticationContext authenticationContext =
+				OAuth2AuthenticationContext.with(userInfoAuthentication)
+						.put(OAuth2Token.class, accessTokenAuthentication.getToken())
+						.put(OAuth2Authorization.class, authorization)
+						.build();
 
 		OidcUserInfo userInfo = this.userInfoMapper.apply(authenticationContext);