Browse Source

Allow logout+jwt JWT type for reactive

The OIDC back-channel spec recommends using a logout token typ `logout+jwt`
(see [here](https://openid.net/specs/openid-connect-backchannel-1_0-final.html#LogoutToken).

Support of this type was recently added [on the servlet side]([on the Servlet side](https://github.com/spring-projects/spring-security/commit/9101bf1f7d5ad5d57ebfa2df78135704cdc19b23)), so back
porting the same on the reactive side to close the gap.

Closes gh-15702
Cedric Montfort 11 months ago
parent
commit
aceb5fa6bb

+ 30 - 5
config/src/main/java/org/springframework/security/config/web/server/OidcBackChannelLogoutReactiveAuthenticationManager.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2023 the original author or authors.
+ * Copyright 2002-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,6 +16,10 @@
 
 package org.springframework.security.config.web.server;
 
+import com.nimbusds.jose.JOSEObjectType;
+import com.nimbusds.jose.proc.DefaultJOSEObjectTypeVerifier;
+import com.nimbusds.jose.proc.JOSEObjectTypeVerifier;
+import com.nimbusds.jose.proc.JWKSecurityContext;
 import reactor.core.publisher.Mono;
 
 import org.springframework.security.authentication.AuthenticationProvider;
@@ -23,19 +27,22 @@ import org.springframework.security.authentication.AuthenticationServiceExceptio
 import org.springframework.security.authentication.ReactiveAuthenticationManager;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
-import org.springframework.security.oauth2.client.oidc.authentication.ReactiveOidcIdTokenDecoderFactory;
+import org.springframework.security.oauth2.client.oidc.authentication.OidcIdTokenDecoderFactory;
 import org.springframework.security.oauth2.client.oidc.authentication.logout.OidcLogoutToken;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
+import org.springframework.security.oauth2.core.converter.ClaimTypeConverter;
 import org.springframework.security.oauth2.jwt.BadJwtException;
 import org.springframework.security.oauth2.jwt.Jwt;
 import org.springframework.security.oauth2.jwt.JwtDecoder;
 import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
+import org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder;
 import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder;
 import org.springframework.security.oauth2.jwt.ReactiveJwtDecoderFactory;
 import org.springframework.util.Assert;
+import org.springframework.util.StringUtils;
 
 /**
  * An {@link AuthenticationProvider} that authenticates an OIDC Logout Token; namely
@@ -61,9 +68,27 @@ final class OidcBackChannelLogoutReactiveAuthenticationManager implements Reacti
 	 * Construct an {@link OidcBackChannelLogoutReactiveAuthenticationManager}
 	 */
 	OidcBackChannelLogoutReactiveAuthenticationManager() {
-		ReactiveOidcIdTokenDecoderFactory logoutTokenDecoderFactory = new ReactiveOidcIdTokenDecoderFactory();
-		logoutTokenDecoderFactory.setJwtValidatorFactory(new DefaultOidcLogoutTokenValidatorFactory());
-		this.logoutTokenDecoderFactory = logoutTokenDecoderFactory;
+		DefaultOidcLogoutTokenValidatorFactory jwtValidator = new DefaultOidcLogoutTokenValidatorFactory();
+		this.logoutTokenDecoderFactory = (clientRegistration) -> {
+			String jwkSetUri = clientRegistration.getProviderDetails().getJwkSetUri();
+			if (!StringUtils.hasText(jwkSetUri)) {
+				OAuth2Error oauth2Error = new OAuth2Error("missing_signature_verifier",
+						"Failed to find a Signature Verifier for Client Registration: '"
+								+ clientRegistration.getRegistrationId()
+								+ "'. Check to ensure you have configured the JwkSet URI.",
+						null);
+				throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
+			}
+			JOSEObjectTypeVerifier<JWKSecurityContext> typeVerifier = new DefaultJOSEObjectTypeVerifier<>(null,
+					JOSEObjectType.JWT, new JOSEObjectType("logout+jwt"));
+			NimbusReactiveJwtDecoder decoder = NimbusReactiveJwtDecoder.withJwkSetUri(jwkSetUri)
+				.jwtProcessorCustomizer((processor) -> processor.setJWSTypeVerifier(typeVerifier))
+				.build();
+			decoder.setJwtValidator(jwtValidator.apply(clientRegistration));
+			decoder.setClaimSetConverter(
+					new ClaimTypeConverter(OidcIdTokenDecoderFactory.createDefaultClaimTypeConverters()));
+			return decoder;
+		};
 	}
 
 	/**

+ 8 - 4
config/src/test/java/org/springframework/security/config/web/server/OidcLogoutSpecTests.java

@@ -75,6 +75,8 @@ import org.springframework.security.oauth2.client.registration.TestClientRegistr
 import org.springframework.security.oauth2.core.oidc.OidcIdToken;
 import org.springframework.security.oauth2.core.oidc.TestOidcIdTokens;
 import org.springframework.security.oauth2.core.oidc.user.OidcUser;
+import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
+import org.springframework.security.oauth2.jwt.JwsHeader;
 import org.springframework.security.oauth2.jwt.JwtClaimsSet;
 import org.springframework.security.oauth2.jwt.JwtEncoder;
 import org.springframework.security.oauth2.jwt.JwtEncoderParameters;
@@ -819,8 +821,9 @@ public class OidcLogoutSpecTests {
 		String logoutToken(@AuthenticationPrincipal OidcUser user) {
 			OidcLogoutToken token = TestOidcLogoutTokens.withUser(user)
 					.audience(List.of(this.registration.getClientId())).build();
-			JwtEncoderParameters parameters = JwtEncoderParameters
-					.from(JwtClaimsSet.builder().claims((claims) -> claims.putAll(token.getClaims())).build());
+			JwsHeader header = JwsHeader.with(SignatureAlgorithm.RS256).type("logout+jwt").build();
+			JwtClaimsSet claims = JwtClaimsSet.builder().claims((c) -> c.putAll(token.getClaims())).build();
+			JwtEncoderParameters parameters = JwtEncoderParameters.from(header, claims);
 			return this.encoder.encode(parameters).getTokenValue();
 		}
 
@@ -829,8 +832,9 @@ public class OidcLogoutSpecTests {
 			OidcLogoutToken token = TestOidcLogoutTokens.withUser(user)
 					.audience(List.of(this.registration.getClientId()))
 					.claims((claims) -> claims.remove(LogoutTokenClaimNames.SID)).build();
-			JwtEncoderParameters parameters = JwtEncoderParameters
-					.from(JwtClaimsSet.builder().claims((claims) -> claims.putAll(token.getClaims())).build());
+			JwsHeader header = JwsHeader.with(SignatureAlgorithm.RS256).type("JWT").build();
+			JwtClaimsSet claims = JwtClaimsSet.builder().claims((c) -> c.putAll(token.getClaims())).build();
+			JwtEncoderParameters parameters = JwtEncoderParameters.from(header, claims);
 			return this.encoder.encode(parameters).getTokenValue();
 		}
 	}