Переглянути джерело

Allow OAuth2AuthorizationRequest to be extended

Closes gh-18049
Joe Grandja 1 день тому
батько
коміт
fbf7bb3be1

+ 101 - 81
oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequest.java

@@ -16,6 +16,7 @@
 
 package org.springframework.security.oauth2.core.endpoint;
 
+import java.io.Serial;
 import java.io.Serializable;
 import java.net.URI;
 import java.nio.charset.StandardCharsets;
@@ -51,31 +52,46 @@ import org.springframework.web.util.UriUtils;
  * "https://tools.ietf.org/html/rfc6749#section-4.1.1">Section 4.1.1 Authorization Code
  * Grant Request</a>
  */
-public final class OAuth2AuthorizationRequest implements Serializable {
+public class OAuth2AuthorizationRequest implements Serializable {
 
+	@Serial
 	private static final long serialVersionUID = 620L;
 
-	private String authorizationUri;
+	private final String authorizationUri;
 
-	private AuthorizationGrantType authorizationGrantType;
+	private final AuthorizationGrantType authorizationGrantType;
 
-	private OAuth2AuthorizationResponseType responseType;
+	private final OAuth2AuthorizationResponseType responseType;
 
-	private String clientId;
+	private final String clientId;
 
-	private String redirectUri;
+	private final String redirectUri;
 
-	private Set<String> scopes;
+	private final Set<String> scopes;
 
-	private String state;
+	private final String state;
 
-	private Map<String, Object> additionalParameters;
+	private final Map<String, Object> additionalParameters;
 
-	private String authorizationRequestUri;
+	private final String authorizationRequestUri;
 
-	private Map<String, Object> attributes;
+	private final Map<String, Object> attributes;
 
-	private OAuth2AuthorizationRequest() {
+	protected OAuth2AuthorizationRequest(AbstractBuilder<?, ?> builder) {
+		Assert.hasText(builder.authorizationUri, "authorizationUri cannot be empty");
+		Assert.hasText(builder.clientId, "clientId cannot be empty");
+		this.authorizationUri = builder.authorizationUri;
+		this.authorizationGrantType = builder.authorizationGrantType;
+		this.responseType = builder.responseType;
+		this.clientId = builder.clientId;
+		this.redirectUri = builder.redirectUri;
+		this.scopes = Collections.unmodifiableSet(
+				CollectionUtils.isEmpty(builder.scopes) ? Collections.emptySet() : new LinkedHashSet<>(builder.scopes));
+		this.state = builder.state;
+		this.additionalParameters = Collections.unmodifiableMap(builder.additionalParameters);
+		this.authorizationRequestUri = StringUtils.hasText(builder.authorizationRequestUri)
+				? builder.authorizationRequestUri : builder.buildAuthorizationRequestUri();
+		this.attributes = Collections.unmodifiableMap(builder.attributes);
 	}
 
 	/**
@@ -185,7 +201,7 @@ public final class OAuth2AuthorizationRequest implements Serializable {
 	 * @return the {@link Builder}
 	 */
 	public static Builder authorizationCode() {
-		return new Builder(AuthorizationGrantType.AUTHORIZATION_CODE);
+		return new Builder();
 	}
 
 	@Override
@@ -226,7 +242,7 @@ public final class OAuth2AuthorizationRequest implements Serializable {
 	public static Builder from(OAuth2AuthorizationRequest authorizationRequest) {
 		Assert.notNull(authorizationRequest, "authorizationRequest cannot be null");
 		// @formatter:off
-		return new Builder(authorizationRequest.getGrantType())
+		return new Builder()
 				.authorizationUri(authorizationRequest.getAuthorizationUri())
 				.clientId(authorizationRequest.getClientId())
 				.redirectUri(authorizationRequest.getRedirectUri())
@@ -240,13 +256,32 @@ public final class OAuth2AuthorizationRequest implements Serializable {
 	/**
 	 * A builder for {@link OAuth2AuthorizationRequest}.
 	 */
-	public static final class Builder {
+	public static class Builder extends AbstractBuilder<OAuth2AuthorizationRequest, Builder> {
+
+		/**
+		 * Builds a new {@link OAuth2AuthorizationRequest}.
+		 * @return a {@link OAuth2AuthorizationRequest}
+		 */
+		@Override
+		public OAuth2AuthorizationRequest build() {
+			return new OAuth2AuthorizationRequest(this);
+		}
+
+	}
+
+	/**
+	 * A builder for subclasses of {@link OAuth2AuthorizationRequest}.
+	 *
+	 * @param <T> the type of authorization request
+	 * @param <B> the type of the builder
+	 */
+	protected abstract static class AbstractBuilder<T extends OAuth2AuthorizationRequest, B extends AbstractBuilder<T, B>> {
 
 		private String authorizationUri;
 
-		private AuthorizationGrantType authorizationGrantType;
+		private final AuthorizationGrantType authorizationGrantType = AuthorizationGrantType.AUTHORIZATION_CODE;
 
-		private OAuth2AuthorizationResponseType responseType;
+		private final OAuth2AuthorizationResponseType responseType = OAuth2AuthorizationResponseType.CODE;
 
 		private String clientId;
 
@@ -269,12 +304,7 @@ public final class OAuth2AuthorizationRequest implements Serializable {
 
 		private final DefaultUriBuilderFactory uriBuilderFactory;
 
-		private Builder(AuthorizationGrantType authorizationGrantType) {
-			Assert.notNull(authorizationGrantType, "authorizationGrantType cannot be null");
-			this.authorizationGrantType = authorizationGrantType;
-			if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(authorizationGrantType)) {
-				this.responseType = OAuth2AuthorizationResponseType.CODE;
-			}
+		protected AbstractBuilder() {
 			this.uriBuilderFactory = new DefaultUriBuilderFactory();
 			// The supplied authorizationUri may contain encoded parameters
 			// so disable encoding in UriBuilder and instead apply encoding within this
@@ -282,78 +312,85 @@ public final class OAuth2AuthorizationRequest implements Serializable {
 			this.uriBuilderFactory.setEncodingMode(DefaultUriBuilderFactory.EncodingMode.NONE);
 		}
 
+		@SuppressWarnings("unchecked")
+		protected final B getThis() {
+			// avoid unchecked casts in subclasses by using "getThis()" instead of "(B)
+			// this"
+			return (B) this;
+		}
+
 		/**
 		 * Sets the uri for the authorization endpoint.
 		 * @param authorizationUri the uri for the authorization endpoint
-		 * @return the {@link Builder}
+		 * @return the {@link AbstractBuilder}
 		 */
-		public Builder authorizationUri(String authorizationUri) {
+		public B authorizationUri(String authorizationUri) {
 			this.authorizationUri = authorizationUri;
-			return this;
+			return getThis();
 		}
 
 		/**
 		 * Sets the client identifier.
 		 * @param clientId the client identifier
-		 * @return the {@link Builder}
+		 * @return the {@link AbstractBuilder}
 		 */
-		public Builder clientId(String clientId) {
+		public B clientId(String clientId) {
 			this.clientId = clientId;
-			return this;
+			return getThis();
 		}
 
 		/**
 		 * Sets the uri for the redirection endpoint.
 		 * @param redirectUri the uri for the redirection endpoint
-		 * @return the {@link Builder}
+		 * @return the {@link AbstractBuilder}
 		 */
-		public Builder redirectUri(String redirectUri) {
+		public B redirectUri(String redirectUri) {
 			this.redirectUri = redirectUri;
-			return this;
+			return getThis();
 		}
 
 		/**
 		 * Sets the scope(s).
 		 * @param scope the scope(s)
-		 * @return the {@link Builder}
+		 * @return the {@link AbstractBuilder}
 		 */
-		public Builder scope(String... scope) {
+		public B scope(String... scope) {
 			if (scope != null && scope.length > 0) {
 				return scopes(new LinkedHashSet<>(Arrays.asList(scope)));
 			}
-			return this;
+			return getThis();
 		}
 
 		/**
 		 * Sets the scope(s).
 		 * @param scopes the scope(s)
-		 * @return the {@link Builder}
+		 * @return the {@link AbstractBuilder}
 		 */
-		public Builder scopes(Set<String> scopes) {
+		public B scopes(Set<String> scopes) {
 			this.scopes = scopes;
-			return this;
+			return getThis();
 		}
 
 		/**
 		 * Sets the state.
 		 * @param state the state
-		 * @return the {@link Builder}
+		 * @return the {@link AbstractBuilder}
 		 */
-		public Builder state(String state) {
+		public B state(String state) {
 			this.state = state;
-			return this;
+			return getThis();
 		}
 
 		/**
 		 * Sets the additional parameter(s) used in the request.
 		 * @param additionalParameters the additional parameter(s) used in the request
-		 * @return the {@link Builder}
+		 * @return the {@link AbstractBuilder}
 		 */
-		public Builder additionalParameters(Map<String, Object> additionalParameters) {
+		public B additionalParameters(Map<String, Object> additionalParameters) {
 			if (!CollectionUtils.isEmpty(additionalParameters)) {
 				this.additionalParameters.putAll(additionalParameters);
 			}
-			return this;
+			return getThis();
 		}
 
 		/**
@@ -361,52 +398,55 @@ public final class OAuth2AuthorizationRequest implements Serializable {
 		 * allowing the ability to add, replace, or remove.
 		 * @param additionalParametersConsumer a {@code Consumer} of the additional
 		 * parameters
+		 * @return the {@link AbstractBuilder}
 		 * @since 5.3
 		 */
-		public Builder additionalParameters(Consumer<Map<String, Object>> additionalParametersConsumer) {
+		public B additionalParameters(Consumer<Map<String, Object>> additionalParametersConsumer) {
 			if (additionalParametersConsumer != null) {
 				additionalParametersConsumer.accept(this.additionalParameters);
 			}
-			return this;
+			return getThis();
 		}
 
 		/**
 		 * A {@code Consumer} to be provided access to all the parameters allowing the
 		 * ability to add, replace, or remove.
 		 * @param parametersConsumer a {@code Consumer} of all the parameters
+		 * @return the {@link AbstractBuilder}
 		 * @since 5.3
 		 */
-		public Builder parameters(Consumer<Map<String, Object>> parametersConsumer) {
+		public B parameters(Consumer<Map<String, Object>> parametersConsumer) {
 			if (parametersConsumer != null) {
 				this.parametersConsumer = parametersConsumer;
 			}
-			return this;
+			return getThis();
 		}
 
 		/**
 		 * Sets the attributes associated to the request.
 		 * @param attributes the attributes associated to the request
-		 * @return the {@link Builder}
+		 * @return the {@link AbstractBuilder}
 		 * @since 5.2
 		 */
-		public Builder attributes(Map<String, Object> attributes) {
+		public B attributes(Map<String, Object> attributes) {
 			if (!CollectionUtils.isEmpty(attributes)) {
 				this.attributes.putAll(attributes);
 			}
-			return this;
+			return getThis();
 		}
 
 		/**
 		 * A {@code Consumer} to be provided access to the attribute(s) allowing the
 		 * ability to add, replace, or remove.
 		 * @param attributesConsumer a {@code Consumer} of the attribute(s)
+		 * @return the {@link AbstractBuilder}
 		 * @since 5.3
 		 */
-		public Builder attributes(Consumer<Map<String, Object>> attributesConsumer) {
+		public B attributes(Consumer<Map<String, Object>> attributesConsumer) {
 			if (attributesConsumer != null) {
 				attributesConsumer.accept(this.attributes);
 			}
-			return this;
+			return getThis();
 		}
 
 		/**
@@ -418,12 +458,12 @@ public final class OAuth2AuthorizationRequest implements Serializable {
 		 * {@code application/x-www-form-urlencoded} MIME format.
 		 * @param authorizationRequestUri the {@code URI} string representation of the
 		 * OAuth 2.0 Authorization Request
-		 * @return the {@link Builder}
+		 * @return the {@link AbstractBuilder}
 		 * @since 5.1
 		 */
-		public Builder authorizationRequestUri(String authorizationRequestUri) {
+		public B authorizationRequestUri(String authorizationRequestUri) {
 			this.authorizationRequestUri = authorizationRequestUri;
-			return this;
+			return getThis();
 		}
 
 		/**
@@ -431,37 +471,17 @@ public final class OAuth2AuthorizationRequest implements Serializable {
 		 * OAuth 2.0 Authorization Request allowing for further customizations.
 		 * @param authorizationRequestUriFunction a {@code Function} to be provided a
 		 * {@code UriBuilder} representation of the OAuth 2.0 Authorization Request
+		 * @return the {@link AbstractBuilder}
 		 * @since 5.3
 		 */
-		public Builder authorizationRequestUri(Function<UriBuilder, URI> authorizationRequestUriFunction) {
+		public B authorizationRequestUri(Function<UriBuilder, URI> authorizationRequestUriFunction) {
 			if (authorizationRequestUriFunction != null) {
 				this.authorizationRequestUriFunction = authorizationRequestUriFunction;
 			}
-			return this;
+			return getThis();
 		}
 
-		/**
-		 * Builds a new {@link OAuth2AuthorizationRequest}.
-		 * @return a {@link OAuth2AuthorizationRequest}
-		 */
-		public OAuth2AuthorizationRequest build() {
-			Assert.hasText(this.authorizationUri, "authorizationUri cannot be empty");
-			Assert.hasText(this.clientId, "clientId cannot be empty");
-			OAuth2AuthorizationRequest authorizationRequest = new OAuth2AuthorizationRequest();
-			authorizationRequest.authorizationUri = this.authorizationUri;
-			authorizationRequest.authorizationGrantType = this.authorizationGrantType;
-			authorizationRequest.responseType = this.responseType;
-			authorizationRequest.clientId = this.clientId;
-			authorizationRequest.redirectUri = this.redirectUri;
-			authorizationRequest.state = this.state;
-			authorizationRequest.scopes = Collections.unmodifiableSet(
-					CollectionUtils.isEmpty(this.scopes) ? Collections.emptySet() : new LinkedHashSet<>(this.scopes));
-			authorizationRequest.additionalParameters = Collections.unmodifiableMap(this.additionalParameters);
-			authorizationRequest.attributes = Collections.unmodifiableMap(this.attributes);
-			authorizationRequest.authorizationRequestUri = StringUtils.hasText(this.authorizationRequestUri)
-					? this.authorizationRequestUri : this.buildAuthorizationRequestUri();
-			return authorizationRequest;
-		}
+		public abstract T build();
 
 		private String buildAuthorizationRequestUri() {
 			Map<String, Object> parameters = getParameters(); // Not encoded
@@ -486,7 +506,7 @@ public final class OAuth2AuthorizationRequest implements Serializable {
 			return this.authorizationRequestUriFunction.apply(uriBuilder).toString();
 		}
 
-		private Map<String, Object> getParameters() {
+		protected Map<String, Object> getParameters() {
 			Map<String, Object> parameters = new LinkedHashMap<>();
 			parameters.put(OAuth2ParameterNames.RESPONSE_TYPE, this.responseType.getValue());
 			parameters.put(OAuth2ParameterNames.CLIENT_ID, this.clientId);

+ 36 - 0
oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequestTests.java

@@ -388,4 +388,40 @@ public class OAuth2AuthorizationRequestTests {
 		assertThat(authorizationRequest1HashCode).isEqualTo(authorizationRequest2HashCode);
 	}
 
+	@Test
+	public void buildWhenExtendedTypeAndAllValuesProvidedThenAllValuesAreSet() {
+		Map<String, Object> additionalParameters = new HashMap<>();
+		additionalParameters.put("param1", "value1");
+		additionalParameters.put("param2", "value2");
+		Map<String, Object> attributes = new HashMap<>();
+		attributes.put("attribute1", "value1");
+		attributes.put("attribute2", "value2");
+		// @formatter:off
+		TestOidcAuthorizationRequest oidcAuthorizationRequest = TestOidcAuthorizationRequest.builder()
+				.authorizationUri(AUTHORIZATION_URI)
+				.clientId(CLIENT_ID)
+				.redirectUri(REDIRECT_URI)
+				.scopes(SCOPES)
+				.state(STATE)
+				.additionalParameters(additionalParameters)
+				.attributes(attributes)
+				.nonce("nonce1234")
+				.build();
+		// @formatter:on
+		assertThat(oidcAuthorizationRequest.getAuthorizationUri()).isEqualTo(AUTHORIZATION_URI);
+		assertThat(oidcAuthorizationRequest.getGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
+		assertThat(oidcAuthorizationRequest.getResponseType()).isEqualTo(OAuth2AuthorizationResponseType.CODE);
+		assertThat(oidcAuthorizationRequest.getClientId()).isEqualTo(CLIENT_ID);
+		assertThat(oidcAuthorizationRequest.getRedirectUri()).isEqualTo(REDIRECT_URI);
+		assertThat(oidcAuthorizationRequest.getScopes()).isEqualTo(SCOPES);
+		assertThat(oidcAuthorizationRequest.getState()).isEqualTo(STATE);
+		assertThat(oidcAuthorizationRequest.getAdditionalParameters()).isEqualTo(additionalParameters);
+		assertThat(oidcAuthorizationRequest.getAttributes()).isEqualTo(attributes);
+		assertThat(oidcAuthorizationRequest.getNonce()).isEqualTo("nonce1234");
+		assertThat(oidcAuthorizationRequest.getAuthorizationRequestUri())
+			.isEqualTo("https://provider.com/oauth2/authorize?" + "response_type=code&client_id=client-id&"
+					+ "scope=scope1%20scope2&state=state&"
+					+ "redirect_uri=https://example.com&param1=value1&param2=value2&nonce=nonce1234");
+	}
+
 }

+ 68 - 0
oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/TestOidcAuthorizationRequest.java

@@ -0,0 +1,68 @@
+/*
+ * Copyright 2004-present the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.core.endpoint;
+
+import java.util.Map;
+
+import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
+
+/**
+ * @author Joe Grandja
+ */
+public class TestOidcAuthorizationRequest extends OAuth2AuthorizationRequest {
+
+	private final String nonce;
+
+	protected TestOidcAuthorizationRequest(Builder builder) {
+		super(builder);
+		this.nonce = builder.nonce;
+	}
+
+	public String getNonce() {
+		return this.nonce;
+	}
+
+	public static Builder builder() {
+		return new Builder();
+	}
+
+	public static class Builder extends AbstractBuilder<TestOidcAuthorizationRequest, Builder> {
+
+		private String nonce;
+
+		public Builder nonce(String nonce) {
+			this.nonce = nonce;
+			return this;
+		}
+
+		@Override
+		public TestOidcAuthorizationRequest build() {
+			return new TestOidcAuthorizationRequest(this);
+		}
+
+		@Override
+		protected Map<String, Object> getParameters() {
+			Map<String, Object> parameters = super.getParameters();
+			if (this.nonce != null) {
+				parameters.put(OidcParameterNames.NONCE, this.nonce);
+			}
+			return parameters;
+		}
+
+	}
+
+}