瀏覽代碼

Allow Jwt assertion to be resolved

Closes gh-9812
Joe Grandja 3 年之前
父節點
當前提交
525f40490c

+ 6 - 0
docs/modules/ROOT/pages/reactive/oauth2/client/authorization-grants.adoc

@@ -1098,3 +1098,9 @@ class OAuth2ResourceServerController {
 }
 ----
 ====
+
+[NOTE]
+`JwtBearerReactiveOAuth2AuthorizedClientProvider` resolves the `Jwt` assertion via `OAuth2AuthorizationContext.getPrincipal().getPrincipal()` by default, hence the use of `JwtAuthenticationToken` in the preceding example.
+
+[TIP]
+If you need to resolve the `Jwt` assertion from a different source, you can provide `JwtBearerReactiveOAuth2AuthorizedClientProvider.setJwtAssertionResolver()` with a custom `Function<OAuth2AuthorizationContext, Mono<Jwt>>`.

+ 6 - 0
docs/modules/ROOT/pages/servlet/oauth2/client/authorization-grants.adoc

@@ -1352,3 +1352,9 @@ class OAuth2ResourceServerController {
 }
 ----
 ====
+
+[NOTE]
+`JwtBearerOAuth2AuthorizedClientProvider` resolves the `Jwt` assertion via `OAuth2AuthorizationContext.getPrincipal().getPrincipal()` by default, hence the use of `JwtAuthenticationToken` in the preceding example.
+
+[TIP]
+If you need to resolve the `Jwt` assertion from a different source, you can provide `JwtBearerOAuth2AuthorizedClientProvider.setJwtAssertionResolver()` with a custom `Function<OAuth2AuthorizationContext, Jwt>`.

+ 24 - 3
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/JwtBearerOAuth2AuthorizedClientProvider.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2021 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -19,6 +19,7 @@ package org.springframework.security.oauth2.client;
 import java.time.Clock;
 import java.time.Duration;
 import java.time.Instant;
+import java.util.function.Function;
 
 import org.springframework.lang.Nullable;
 import org.springframework.security.oauth2.client.endpoint.DefaultJwtBearerTokenResponseClient;
@@ -45,6 +46,8 @@ public final class JwtBearerOAuth2AuthorizedClientProvider implements OAuth2Auth
 
 	private OAuth2AccessTokenResponseClient<JwtBearerGrantRequest> accessTokenResponseClient = new DefaultJwtBearerTokenResponseClient();
 
+	private Function<OAuth2AuthorizationContext, Jwt> jwtAssertionResolver = this::resolveJwtAssertion;
+
 	private Duration clockSkew = Duration.ofSeconds(60);
 
 	private Clock clock = Clock.systemUTC();
@@ -75,10 +78,10 @@ public final class JwtBearerOAuth2AuthorizedClientProvider implements OAuth2Auth
 			// need for re-authorization
 			return null;
 		}
-		if (!(context.getPrincipal().getPrincipal() instanceof Jwt)) {
+		Jwt jwt = this.jwtAssertionResolver.apply(context);
+		if (jwt == null) {
 			return null;
 		}
-		Jwt jwt = (Jwt) context.getPrincipal().getPrincipal();
 		// As per spec, in section 4.1 Using Assertions as Authorization Grants
 		// https://tools.ietf.org/html/rfc7521#section-4.1
 		//
@@ -97,6 +100,13 @@ public final class JwtBearerOAuth2AuthorizedClientProvider implements OAuth2Auth
 				tokenResponse.getAccessToken());
 	}
 
+	private Jwt resolveJwtAssertion(OAuth2AuthorizationContext context) {
+		if (!(context.getPrincipal().getPrincipal() instanceof Jwt)) {
+			return null;
+		}
+		return (Jwt) context.getPrincipal().getPrincipal();
+	}
+
 	private OAuth2AccessTokenResponse getTokenResponse(ClientRegistration clientRegistration,
 			JwtBearerGrantRequest jwtBearerGrantRequest) {
 		try {
@@ -123,6 +133,17 @@ public final class JwtBearerOAuth2AuthorizedClientProvider implements OAuth2Auth
 		this.accessTokenResponseClient = accessTokenResponseClient;
 	}
 
+	/**
+	 * Sets the resolver used for resolving the {@link Jwt} assertion.
+	 * @param jwtAssertionResolver the resolver used for resolving the {@link Jwt}
+	 * assertion
+	 * @since 5.7
+	 */
+	public void setJwtAssertionResolver(Function<OAuth2AuthorizationContext, Jwt> jwtAssertionResolver) {
+		Assert.notNull(jwtAssertionResolver, "jwtAssertionResolver cannot be null");
+		this.jwtAssertionResolver = jwtAssertionResolver;
+	}
+
 	/**
 	 * Sets the maximum acceptable clock skew, which is used when checking the
 	 * {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is

+ 30 - 6
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/JwtBearerReactiveOAuth2AuthorizedClientProvider.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2021 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -19,6 +19,7 @@ package org.springframework.security.oauth2.client;
 import java.time.Clock;
 import java.time.Duration;
 import java.time.Instant;
+import java.util.function.Function;
 
 import reactor.core.publisher.Mono;
 
@@ -45,6 +46,8 @@ public final class JwtBearerReactiveOAuth2AuthorizedClientProvider implements Re
 
 	private ReactiveOAuth2AccessTokenResponseClient<JwtBearerGrantRequest> accessTokenResponseClient = new WebClientReactiveJwtBearerTokenResponseClient();
 
+	private Function<OAuth2AuthorizationContext, Mono<Jwt>> jwtAssertionResolver = this::resolveJwtAssertion;
+
 	private Duration clockSkew = Duration.ofSeconds(60);
 
 	private Clock clock = Clock.systemUTC();
@@ -74,10 +77,7 @@ public final class JwtBearerReactiveOAuth2AuthorizedClientProvider implements Re
 			// need for re-authorization
 			return Mono.empty();
 		}
-		if (!(context.getPrincipal().getPrincipal() instanceof Jwt)) {
-			return Mono.empty();
-		}
-		Jwt jwt = (Jwt) context.getPrincipal().getPrincipal();
+
 		// As per spec, in section 4.1 Using Assertions as Authorization Grants
 		// https://tools.ietf.org/html/rfc7521#section-4.1
 		//
@@ -90,13 +90,26 @@ public final class JwtBearerReactiveOAuth2AuthorizedClientProvider implements Re
 		// issued with a reasonably short lifetime. Clients can refresh an
 		// expired access token by requesting a new one using the same
 		// assertion, if it is still valid, or with a new assertion.
-		return Mono.just(new JwtBearerGrantRequest(clientRegistration, jwt))
+
+		// @formatter:off
+		return this.jwtAssertionResolver.apply(context)
+				.map((jwt) -> new JwtBearerGrantRequest(clientRegistration, jwt))
 				.flatMap(this.accessTokenResponseClient::getTokenResponse)
 				.onErrorMap(OAuth2AuthorizationException.class,
 						(ex) -> new ClientAuthorizationException(ex.getError(), clientRegistration.getRegistrationId(),
 								ex))
 				.map((tokenResponse) -> new OAuth2AuthorizedClient(clientRegistration, context.getPrincipal().getName(),
 						tokenResponse.getAccessToken()));
+		// @formatter:on
+	}
+
+	private Mono<Jwt> resolveJwtAssertion(OAuth2AuthorizationContext context) {
+		// @formatter:off
+		return Mono.just(context)
+				.map((ctx) -> ctx.getPrincipal().getPrincipal())
+				.filter((principal) -> principal instanceof Jwt)
+				.cast(Jwt.class);
+		// @formatter:on
 	}
 
 	private boolean hasTokenExpired(OAuth2Token token) {
@@ -115,6 +128,17 @@ public final class JwtBearerReactiveOAuth2AuthorizedClientProvider implements Re
 		this.accessTokenResponseClient = accessTokenResponseClient;
 	}
 
+	/**
+	 * Sets the resolver used for resolving the {@link Jwt} assertion.
+	 * @param jwtAssertionResolver the resolver used for resolving the {@link Jwt}
+	 * assertion
+	 * @since 5.7
+	 */
+	public void setJwtAssertionResolver(Function<OAuth2AuthorizationContext, Mono<Jwt>> jwtAssertionResolver) {
+		Assert.notNull(jwtAssertionResolver, "jwtAssertionResolver cannot be null");
+		this.jwtAssertionResolver = jwtAssertionResolver;
+	}
+
 	/**
 	 * Sets the maximum acceptable clock skew, which is used when checking the
 	 * {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is

+ 33 - 3
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/JwtBearerOAuth2AuthorizedClientProviderTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2021 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@ package org.springframework.security.oauth2.client;
 
 import java.time.Duration;
 import java.time.Instant;
+import java.util.function.Function;
 
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
@@ -42,6 +43,7 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
 
 /**
  * Tests for {@link JwtBearerOAuth2AuthorizedClientProvider}.
@@ -87,6 +89,13 @@ public class JwtBearerOAuth2AuthorizedClientProviderTests {
 				.withMessage("accessTokenResponseClient cannot be null");
 	}
 
+	@Test
+	public void setJwtAssertionResolverWhenNullThenThrowIllegalArgumentException() {
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> this.authorizedClientProvider.setJwtAssertionResolver(null))
+				.withMessage("jwtAssertionResolver cannot be null");
+	}
+
 	@Test
 	public void setClockSkewWhenNullThenThrowIllegalArgumentException() {
 		// @formatter:off
@@ -198,7 +207,7 @@ public class JwtBearerOAuth2AuthorizedClientProviderTests {
 	}
 
 	@Test
-	public void authorizeWhenJwtBearerAndNotAuthorizedAndPrincipalNotJwtThenUnableToAuthorize() {
+	public void authorizeWhenJwtBearerAndNotAuthorizedAndJwtDoesNotResolveThenUnableToAuthorize() {
 		// @formatter:off
 		OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
 				.withClientRegistration(this.clientRegistration)
@@ -209,7 +218,7 @@ public class JwtBearerOAuth2AuthorizedClientProviderTests {
 	}
 
 	@Test
-	public void authorizeWhenJwtBearerAndNotAuthorizedAndPrincipalJwtThenAuthorize() {
+	public void authorizeWhenJwtBearerAndNotAuthorizedAndJwtResolvesThenAuthorize() {
 		OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
 		given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse);
 		// @formatter:off
@@ -224,4 +233,25 @@ public class JwtBearerOAuth2AuthorizedClientProviderTests {
 		assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
 	}
 
+	@Test
+	public void authorizeWhenCustomJwtAssertionResolverSetThenUsed() {
+		Function<OAuth2AuthorizationContext, Jwt> jwtAssertionResolver = mock(Function.class);
+		given(jwtAssertionResolver.apply(any())).willReturn(this.jwtAssertion);
+		this.authorizedClientProvider.setJwtAssertionResolver(jwtAssertionResolver);
+		OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
+		given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse);
+		// @formatter:off
+		TestingAuthenticationToken principal = new TestingAuthenticationToken("user", "password");
+		OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
+				.withClientRegistration(this.clientRegistration)
+				.principal(principal)
+				.build();
+		// @formatter:on
+		OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext);
+		verify(jwtAssertionResolver).apply(any());
+		assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
+		assertThat(authorizedClient.getPrincipalName()).isEqualTo(principal.getName());
+		assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
+	}
+
 }

+ 32 - 3
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/JwtBearerReactiveOAuth2AuthorizedClientProviderTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2021 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -19,6 +19,7 @@ package org.springframework.security.oauth2.client;
 import java.time.Clock;
 import java.time.Duration;
 import java.time.Instant;
+import java.util.function.Function;
 
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
@@ -93,6 +94,13 @@ public class JwtBearerReactiveOAuth2AuthorizedClientProviderTests {
 				.withMessage("accessTokenResponseClient cannot be null");
 	}
 
+	@Test
+	public void setJwtAssertionResolverWhenNullThenThrowIllegalArgumentException() {
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> this.authorizedClientProvider.setJwtAssertionResolver(null))
+				.withMessage("jwtAssertionResolver cannot be null");
+	}
+
 	@Test
 	public void setClockSkewWhenNullThenThrowIllegalArgumentException() {
 		// @formatter:off
@@ -222,7 +230,7 @@ public class JwtBearerReactiveOAuth2AuthorizedClientProviderTests {
 	}
 
 	@Test
-	public void authorizeWhenJwtBearerAndNotAuthorizedAndPrincipalNotJwtThenUnableToAuthorize() {
+	public void authorizeWhenJwtBearerAndNotAuthorizedAndJwtDoesNotResolveThenUnableToAuthorize() {
 		// @formatter:off
 		OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
 				.withClientRegistration(this.clientRegistration)
@@ -251,7 +259,7 @@ public class JwtBearerReactiveOAuth2AuthorizedClientProviderTests {
 	}
 
 	@Test
-	public void authorizeWhenJwtBearerAndNotAuthorizedAndPrincipalJwtThenAuthorize() {
+	public void authorizeWhenJwtBearerAndNotAuthorizedAndJwtResolvesThenAuthorize() {
 		OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
 		given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse));
 		// @formatter:off
@@ -266,4 +274,25 @@ public class JwtBearerReactiveOAuth2AuthorizedClientProviderTests {
 		assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
 	}
 
+	@Test
+	public void authorizeWhenCustomJwtAssertionResolverSetThenUsed() {
+		Function<OAuth2AuthorizationContext, Mono<Jwt>> jwtAssertionResolver = mock(Function.class);
+		given(jwtAssertionResolver.apply(any())).willReturn(Mono.just(this.jwtAssertion));
+		this.authorizedClientProvider.setJwtAssertionResolver(jwtAssertionResolver);
+		OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
+		given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse));
+		// @formatter:off
+		TestingAuthenticationToken principal = new TestingAuthenticationToken("user", "password");
+		OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
+				.withClientRegistration(this.clientRegistration)
+				.principal(principal)
+				.build();
+		// @formatter:on
+		OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext).block();
+		verify(jwtAssertionResolver).apply(any());
+		assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
+		assertThat(authorizedClient.getPrincipalName()).isEqualTo(principal.getName());
+		assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
+	}
+
 }