Эх сурвалжийг харах

Hash the sid claim in the ID Token

Closes gh-1207
Joe Grandja 2 жил өмнө
parent
commit
a70783e6e7

+ 19 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java

@@ -15,8 +15,12 @@
  */
 package org.springframework.security.oauth2.server.authorization.authentication;
 
+import java.nio.charset.StandardCharsets;
+import java.security.MessageDigest;
+import java.security.NoSuchAlgorithmException;
 import java.security.Principal;
 import java.util.ArrayList;
+import java.util.Base64;
 import java.util.Collections;
 import java.util.Comparator;
 import java.util.HashMap;
@@ -234,6 +238,15 @@ public final class OAuth2AuthorizationCodeAuthenticationProvider implements Auth
 		if (authorizationRequest.getScopes().contains(OidcScopes.OPENID)) {
 			SessionInformation sessionInformation = getSessionInformation(principal);
 			if (sessionInformation != null) {
+				try {
+					// Compute (and use) hash for Session ID
+					sessionInformation = new SessionInformation(sessionInformation.getPrincipal(),
+							createHash(sessionInformation.getSessionId()), sessionInformation.getLastRequest());
+				} catch (NoSuchAlgorithmException ex) {
+					OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR,
+							"Failed to compute hash for Session ID.", ERROR_URI);
+					throw new OAuth2AuthenticationException(error);
+				}
 				tokenContextBuilder.put(SessionInformation.class, sessionInformation);
 			}
 			// @formatter:off
@@ -319,4 +332,10 @@ public final class OAuth2AuthorizationCodeAuthenticationProvider implements Auth
 		return sessionInformation;
 	}
 
+	private static String createHash(String value) throws NoSuchAlgorithmException {
+		MessageDigest md = MessageDigest.getInstance("SHA-256");
+		byte[] digest = md.digest(value.getBytes(StandardCharsets.US_ASCII));
+		return Base64.getUrlEncoder().withoutPadding().encodeToString(digest);
+	}
+
 }

+ 19 - 1
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationProvider.java

@@ -15,6 +15,10 @@
  */
 package org.springframework.security.oauth2.server.authorization.oidc.authentication;
 
+import java.nio.charset.StandardCharsets;
+import java.security.MessageDigest;
+import java.security.NoSuchAlgorithmException;
+import java.util.Base64;
 import java.util.List;
 
 import org.apache.commons.logging.Log;
@@ -137,9 +141,18 @@ public final class OidcLogoutAuthenticationProvider implements AuthenticationPro
 				SessionInformation sessionInformation = findSessionInformation(
 						userPrincipal, oidcLogoutAuthentication.getSessionId());
 				if (sessionInformation != null) {
+					String sessionIdHash;
+					try {
+						sessionIdHash = createHash(sessionInformation.getSessionId());
+					} catch (NoSuchAlgorithmException ex) {
+						OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR,
+								"Failed to compute hash for Session ID.", null);
+						throw new OAuth2AuthenticationException(error);
+					}
+
 					String sidClaim = idToken.getClaim("sid");
 					if (!StringUtils.hasText(sidClaim) ||
-							!sidClaim.equals(sessionInformation.getSessionId())) {
+							!sidClaim.equals(sessionIdHash)) {
 						throwError(OAuth2ErrorCodes.INVALID_TOKEN, "sid");
 					}
 				}
@@ -182,4 +195,9 @@ public final class OidcLogoutAuthenticationProvider implements AuthenticationPro
 		throw new OAuth2AuthenticationException(error);
 	}
 
+	private static String createHash(String value) throws NoSuchAlgorithmException {
+		MessageDigest md = MessageDigest.getInstance("SHA-256");
+		byte[] digest = md.digest(value.getBytes(StandardCharsets.US_ASCII));
+		return Base64.getUrlEncoder().withoutPadding().encodeToString(digest);
+	}
 }

+ 14 - 2
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java

@@ -15,11 +15,15 @@
  */
 package org.springframework.security.oauth2.server.authorization.authentication;
 
+import java.nio.charset.StandardCharsets;
+import java.security.MessageDigest;
+import java.security.NoSuchAlgorithmException;
 import java.security.Principal;
 import java.time.Duration;
 import java.time.Instant;
 import java.time.temporal.ChronoUnit;
 import java.util.ArrayList;
+import java.util.Base64;
 import java.util.Date;
 import java.util.HashMap;
 import java.util.HashSet;
@@ -460,7 +464,7 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
 	}
 
 	@Test
-	public void authenticateWhenValidCodeAndAuthenticationRequestThenReturnIdToken() {
+	public void authenticateWhenValidCodeAndAuthenticationRequestThenReturnIdToken() throws Exception {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().scope(OidcScopes.OPENID).build();
 		OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode(
 				"code", Instant.now(), Instant.now().plusSeconds(120));
@@ -522,7 +526,8 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
 		assertThat(idTokenContext.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
 		assertThat(idTokenContext.<OAuth2AuthorizationGrantAuthenticationToken>getAuthorizationGrant()).isEqualTo(authentication);
 		SessionInformation sessionInformation = idTokenContext.get(SessionInformation.class);
-		assertThat(sessionInformation).isNotNull().isSameAs(expectedSession);
+		assertThat(sessionInformation).isNotNull();
+		assertThat(sessionInformation.getSessionId()).isEqualTo(createHash(expectedSession.getSessionId()));
 		assertThat(idTokenContext.getJwsHeader()).isNotNull();
 		assertThat(idTokenContext.getClaims()).isNotNull();
 
@@ -710,4 +715,11 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
 				.expiresAt(expiresAt)
 				.build();
 	}
+
+	private static String createHash(String value) throws NoSuchAlgorithmException {
+		MessageDigest md = MessageDigest.getInstance("SHA-256");
+		byte[] digest = md.digest(value.getBytes(StandardCharsets.US_ASCII));
+		return Base64.getUrlEncoder().withoutPadding().encodeToString(digest);
+	}
+
 }

+ 12 - 2
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationProviderTests.java

@@ -15,8 +15,12 @@
  */
 package org.springframework.security.oauth2.server.authorization.oidc.authentication;
 
+import java.nio.charset.StandardCharsets;
+import java.security.MessageDigest;
+import java.security.NoSuchAlgorithmException;
 import java.time.Instant;
 import java.time.temporal.ChronoUnit;
+import java.util.Base64;
 import java.util.Collections;
 import java.util.Date;
 import java.util.List;
@@ -443,7 +447,7 @@ public class OidcLogoutAuthenticationProviderTests {
 	}
 
 	@Test
-	public void authenticateWhenValidIdTokenThenAuthenticated() {
+	public void authenticateWhenValidIdTokenThenAuthenticated() throws Exception {
 		TestingAuthenticationToken principal = new TestingAuthenticationToken("principal", "credentials");
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
 		String sessionId = "session-1";
@@ -453,7 +457,7 @@ public class OidcLogoutAuthenticationProviderTests {
 				.audience(Collections.singleton(registeredClient.getClientId()))
 				.issuedAt(Instant.now().minusSeconds(60).truncatedTo(ChronoUnit.MILLIS))
 				.expiresAt(Instant.now().plusSeconds(60).truncatedTo(ChronoUnit.MILLIS))
-				.claim("sid", sessionId)
+				.claim("sid", createHash(sessionId))
 				.build();
 		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient)
 				.principalName(principal.getName())
@@ -496,4 +500,10 @@ public class OidcLogoutAuthenticationProviderTests {
 		assertThat(authenticationResult.isAuthenticated()).isTrue();
 	}
 
+	private static String createHash(String value) throws NoSuchAlgorithmException {
+		MessageDigest md = MessageDigest.getInstance("SHA-256");
+		byte[] digest = md.digest(value.getBytes(StandardCharsets.US_ASCII));
+		return Base64.getUrlEncoder().withoutPadding().encodeToString(digest);
+	}
+
 }