Browse Source

Add support for OpenID Connect 1.0 prompt=none parameter

Closes gh-501
Joe Grandja 1 year ago
parent
commit
19dfcd4ba9

+ 51 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProvider.java

@@ -16,7 +16,10 @@
 package org.springframework.security.oauth2.server.authorization.authentication;
 
 import java.security.Principal;
+import java.util.Arrays;
 import java.util.Base64;
+import java.util.Collections;
+import java.util.HashSet;
 import java.util.Set;
 import java.util.function.Consumer;
 import java.util.function.Predicate;
@@ -70,6 +73,9 @@ import org.springframework.util.StringUtils;
  * @see <a target="_blank" href=
  * "https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1">Section 4.1.1
  * Authorization Request</a>
+ * @see <a target="_blank" href=
+ * "https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest">Section 3.1.2.1
+ * Authentication Request</a>
  */
 public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implements AuthenticationProvider {
 
@@ -158,6 +164,22 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
 					authorizationCodeRequestAuthentication, registeredClient, null);
 		}
 
+		// prompt (OPTIONAL for OpenID Connect 1.0 Authentication Request)
+		Set<String> promptValues = Collections.emptySet();
+		if (authorizationCodeRequestAuthentication.getScopes().contains(OidcScopes.OPENID)) {
+			String prompt = (String) authorizationCodeRequestAuthentication.getAdditionalParameters().get("prompt");
+			if (StringUtils.hasText(prompt)) {
+				promptValues = new HashSet<>(Arrays.asList(StringUtils.delimitedListToStringArray(prompt, " ")));
+				if (promptValues.contains(OidcPrompts.NONE)) {
+					if (promptValues.contains(OidcPrompts.LOGIN) || promptValues.contains(OidcPrompts.CONSENT)
+							|| promptValues.contains(OidcPrompts.SELECT_ACCOUNT)) {
+						throwError(OAuth2ErrorCodes.INVALID_REQUEST, "prompt", authorizationCodeRequestAuthentication,
+								registeredClient);
+					}
+				}
+			}
+		}
+
 		if (this.logger.isTraceEnabled()) {
 			this.logger.trace("Validated authorization code request parameters");
 		}
@@ -168,6 +190,11 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
 
 		Authentication principal = (Authentication) authorizationCodeRequestAuthentication.getPrincipal();
 		if (!isPrincipalAuthenticated(principal)) {
+			if (promptValues.contains(OidcPrompts.NONE)) {
+				// Return an error instead of displaying the login page (via the
+				// configured AuthenticationEntryPoint)
+				throwError("login_required", "prompt", authorizationCodeRequestAuthentication, registeredClient);
+			}
 			if (this.logger.isTraceEnabled()) {
 				this.logger.trace("Did not authenticate authorization code request since principal not authenticated");
 			}
@@ -192,6 +219,11 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
 		}
 
 		if (this.authorizationConsentRequired.test(authenticationContextBuilder.build())) {
+			if (promptValues.contains(OidcPrompts.NONE)) {
+				// Return an error instead of displaying the consent page
+				throwError("consent_required", "prompt", authorizationCodeRequestAuthentication, registeredClient);
+			}
+
 			String state = DEFAULT_STATE_GENERATOR.generateKey();
 			OAuth2Authorization authorization = authorizationBuilder(registeredClient, principal, authorizationRequest)
 				.attribute(OAuth2ParameterNames.STATE, state)
@@ -425,4 +457,23 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
 		return null;
 	}
 
+	/*
+	 * The values defined for the "prompt" parameter for the OpenID Connect 1.0
+	 * Authentication Request.
+	 */
+	private static final class OidcPrompts {
+
+		private static final String NONE = "none";
+
+		private static final String LOGIN = "login";
+
+		private static final String CONSENT = "consent";
+
+		private static final String SELECT_ACCOUNT = "select_account";
+
+		private OidcPrompts() {
+		}
+
+	}
+
 }

+ 10 - 1
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeRequestAuthenticationConverter.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020-2023 the original author or authors.
+ * Copyright 2020-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -39,6 +39,7 @@ import org.springframework.security.oauth2.server.authorization.web.OAuth2Author
 import org.springframework.security.web.authentication.AuthenticationConverter;
 import org.springframework.security.web.util.matcher.AndRequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
+import org.springframework.util.CollectionUtils;
 import org.springframework.util.MultiValueMap;
 import org.springframework.util.StringUtils;
 
@@ -131,6 +132,14 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationConverter impleme
 			throwError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE_METHOD, PKCE_ERROR_URI);
 		}
 
+		// prompt (OPTIONAL for OpenID Connect 1.0 Authentication Request)
+		if (!CollectionUtils.isEmpty(scopes) && scopes.contains(OidcScopes.OPENID)) {
+			String prompt = parameters.getFirst("prompt");
+			if (StringUtils.hasText(prompt) && parameters.get("prompt").size() != 1) {
+				throwError(OAuth2ErrorCodes.INVALID_REQUEST, "prompt");
+			}
+		}
+
 		Map<String, Object> additionalParameters = new HashMap<>();
 		parameters.forEach((key, value) -> {
 			if (!key.equals(OAuth2ParameterNames.RESPONSE_TYPE) && !key.equals(OAuth2ParameterNames.CLIENT_ID)

+ 73 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProviderTests.java

@@ -366,6 +366,59 @@ public class OAuth2AuthorizationCodeRequestAuthenticationProviderTests {
 					authentication.getRedirectUri()));
 	}
 
+	@Test
+	public void authenticateWhenAuthenticationRequestWithPromptNoneLoginThenThrowOAuth2AuthorizationCodeRequestAuthenticationException() {
+		assertWhenAuthenticationRequestWithPromptThenThrowOAuth2AuthorizationCodeRequestAuthenticationException(
+				"none login");
+	}
+
+	@Test
+	public void authenticateWhenAuthenticationRequestWithPromptNoneConsentThenThrowOAuth2AuthorizationCodeRequestAuthenticationException() {
+		assertWhenAuthenticationRequestWithPromptThenThrowOAuth2AuthorizationCodeRequestAuthenticationException(
+				"none consent");
+	}
+
+	@Test
+	public void authenticateWhenAuthenticationRequestWithPromptNoneSelectAccountThenThrowOAuth2AuthorizationCodeRequestAuthenticationException() {
+		assertWhenAuthenticationRequestWithPromptThenThrowOAuth2AuthorizationCodeRequestAuthenticationException(
+				"none select_account");
+	}
+
+	private void assertWhenAuthenticationRequestWithPromptThenThrowOAuth2AuthorizationCodeRequestAuthenticationException(
+			String prompt) {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().scope(OidcScopes.OPENID).build();
+		given(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
+			.willReturn(registeredClient);
+		String redirectUri = registeredClient.getRedirectUris().toArray(new String[0])[2];
+		Map<String, Object> additionalParameters = new HashMap<>();
+		additionalParameters.put("prompt", prompt);
+		OAuth2AuthorizationCodeRequestAuthenticationToken authentication = new OAuth2AuthorizationCodeRequestAuthenticationToken(
+				AUTHORIZATION_URI, registeredClient.getClientId(), this.principal, redirectUri, STATE,
+				registeredClient.getScopes(), additionalParameters);
+		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+			.isInstanceOf(OAuth2AuthorizationCodeRequestAuthenticationException.class)
+			.satisfies((ex) -> assertAuthenticationException((OAuth2AuthorizationCodeRequestAuthenticationException) ex,
+					OAuth2ErrorCodes.INVALID_REQUEST, "prompt", authentication.getRedirectUri()));
+	}
+
+	@Test
+	public void authenticateWhenPrincipalNotAuthenticatedAndPromptNoneThenThrowOAuth2AuthorizationCodeRequestAuthenticationException() {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().scope(OidcScopes.OPENID).build();
+		given(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
+			.willReturn(registeredClient);
+		this.principal.setAuthenticated(false);
+		String redirectUri = registeredClient.getRedirectUris().toArray(new String[0])[2];
+		Map<String, Object> additionalParameters = new HashMap<>();
+		additionalParameters.put("prompt", "none");
+		OAuth2AuthorizationCodeRequestAuthenticationToken authentication = new OAuth2AuthorizationCodeRequestAuthenticationToken(
+				AUTHORIZATION_URI, registeredClient.getClientId(), this.principal, redirectUri, STATE,
+				registeredClient.getScopes(), additionalParameters);
+		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+			.isInstanceOf(OAuth2AuthorizationCodeRequestAuthenticationException.class)
+			.satisfies((ex) -> assertAuthenticationException((OAuth2AuthorizationCodeRequestAuthenticationException) ex,
+					"login_required", "prompt", authentication.getRedirectUri()));
+	}
+
 	@Test
 	public void authenticateWhenPrincipalNotAuthenticatedThenReturnAuthorizationCodeRequest() {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
@@ -385,6 +438,26 @@ public class OAuth2AuthorizationCodeRequestAuthenticationProviderTests {
 		assertThat(authenticationResult.isAuthenticated()).isFalse();
 	}
 
+	@Test
+	public void authenticateWhenRequireAuthorizationConsentAndPromptNoneThenThrowOAuth2AuthorizationCodeRequestAuthenticationException() {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
+			.scope(OidcScopes.OPENID)
+			.clientSettings(ClientSettings.builder().requireAuthorizationConsent(true).build())
+			.build();
+		given(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
+			.willReturn(registeredClient);
+		String redirectUri = registeredClient.getRedirectUris().toArray(new String[0])[2];
+		Map<String, Object> additionalParameters = new HashMap<>();
+		additionalParameters.put("prompt", "none");
+		OAuth2AuthorizationCodeRequestAuthenticationToken authentication = new OAuth2AuthorizationCodeRequestAuthenticationToken(
+				AUTHORIZATION_URI, registeredClient.getClientId(), this.principal, redirectUri, STATE,
+				registeredClient.getScopes(), additionalParameters);
+		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+			.isInstanceOf(OAuth2AuthorizationCodeRequestAuthenticationException.class)
+			.satisfies((ex) -> assertAuthenticationException((OAuth2AuthorizationCodeRequestAuthenticationException) ex,
+					"consent_required", "prompt", authentication.getRedirectUri()));
+	}
+
 	@Test
 	public void authenticateWhenRequireAuthorizationConsentThenReturnAuthorizationConsent() {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient()

+ 16 - 1
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020-2023 the original author or authors.
+ * Copyright 2020-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -288,6 +288,21 @@ public class OAuth2AuthorizationEndpointFilterTests {
 				});
 	}
 
+	@Test
+	public void doFilterWhenAuthenticationRequestMultiplePromptThenInvalidRequestError() throws Exception {
+		// Setup OpenID Connect request
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().scopes((scopes) -> {
+			scopes.clear();
+			scopes.add(OidcScopes.OPENID);
+		}).build();
+		doFilterWhenAuthorizationRequestInvalidParameterThenError(registeredClient, "prompt",
+				OAuth2ErrorCodes.INVALID_REQUEST, (request) -> {
+					request.addParameter("prompt", "none");
+					request.addParameter("prompt", "login");
+					updateQueryString(request);
+				});
+	}
+
 	@Test
 	public void doFilterWhenAuthorizationRequestAuthenticationExceptionThenErrorResponse() throws Exception {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().redirectUris((redirectUris) -> {