Browse Source

OAuth2AuthorizedClientManager implementation works outside of request

Fixes gh-6780
Joe Grandja 6 years ago
parent
commit
f7d03858f1

+ 142 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizedClientServiceOAuth2AuthorizedClientManager.java

@@ -0,0 +1,142 @@
+/*
+ * Copyright 2002-2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client;
+
+import org.springframework.lang.Nullable;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.util.Assert;
+import org.springframework.util.StringUtils;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.function.Function;
+
+/**
+ * An implementation of an {@link OAuth2AuthorizedClientManager}
+ * that is capable of operating outside of a {@code HttpServletRequest} context,
+ * e.g. in a scheduled/background thread and/or in the service-tier.
+ *
+ * @author Joe Grandja
+ * @since 5.2
+ * @see OAuth2AuthorizedClientManager
+ * @see OAuth2AuthorizedClientProvider
+ * @see OAuth2AuthorizedClientService
+ */
+public final class AuthorizedClientServiceOAuth2AuthorizedClientManager implements OAuth2AuthorizedClientManager {
+	private final ClientRegistrationRepository clientRegistrationRepository;
+	private final OAuth2AuthorizedClientService authorizedClientService;
+	private OAuth2AuthorizedClientProvider authorizedClientProvider = context -> null;
+	private Function<OAuth2AuthorizeRequest, Map<String, Object>> contextAttributesMapper = new DefaultContextAttributesMapper();
+
+	/**
+	 * Constructs an {@code AuthorizedClientServiceOAuth2AuthorizedClientManager} using the provided parameters.
+	 *
+	 * @param clientRegistrationRepository the repository of client registrations
+	 * @param authorizedClientService the authorized client service
+	 */
+	public AuthorizedClientServiceOAuth2AuthorizedClientManager(ClientRegistrationRepository clientRegistrationRepository,
+																OAuth2AuthorizedClientService authorizedClientService) {
+		Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
+		Assert.notNull(authorizedClientService, "authorizedClientService cannot be null");
+		this.clientRegistrationRepository = clientRegistrationRepository;
+		this.authorizedClientService = authorizedClientService;
+	}
+
+	@Nullable
+	@Override
+	public OAuth2AuthorizedClient authorize(OAuth2AuthorizeRequest authorizeRequest) {
+		Assert.notNull(authorizeRequest, "authorizeRequest cannot be null");
+
+		String clientRegistrationId = authorizeRequest.getClientRegistrationId();
+		OAuth2AuthorizedClient authorizedClient = authorizeRequest.getAuthorizedClient();
+		Authentication principal = authorizeRequest.getPrincipal();
+
+		OAuth2AuthorizationContext.Builder contextBuilder;
+		if (authorizedClient != null) {
+			contextBuilder = OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient);
+		} else {
+			ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId);
+			Assert.notNull(clientRegistration, "Could not find ClientRegistration with id '" + clientRegistrationId + "'");
+			authorizedClient = this.authorizedClientService.loadAuthorizedClient(clientRegistrationId, principal.getName());
+			if (authorizedClient != null) {
+				contextBuilder = OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient);
+			} else {
+				contextBuilder = OAuth2AuthorizationContext.withClientRegistration(clientRegistration);
+			}
+		}
+		OAuth2AuthorizationContext authorizationContext = contextBuilder
+				.principal(principal)
+				.attributes(this.contextAttributesMapper.apply(authorizeRequest))
+				.build();
+
+		authorizedClient = this.authorizedClientProvider.authorize(authorizationContext);
+		if (authorizedClient != null) {
+			this.authorizedClientService.saveAuthorizedClient(authorizedClient, principal);
+		} else {
+			// In the case of re-authorization, the returned `authorizedClient` may be null if re-authorization is not supported.
+			// For these cases, return the provided `authorizationContext.authorizedClient`.
+			if (authorizationContext.getAuthorizedClient() != null) {
+				return authorizationContext.getAuthorizedClient();
+			}
+		}
+
+		return authorizedClient;
+	}
+
+	/**
+	 * Sets the {@link OAuth2AuthorizedClientProvider} used for authorizing (or re-authorizing) an OAuth 2.0 Client.
+	 *
+	 * @param authorizedClientProvider the {@link OAuth2AuthorizedClientProvider} used for authorizing (or re-authorizing) an OAuth 2.0 Client
+	 */
+	public void setAuthorizedClientProvider(OAuth2AuthorizedClientProvider authorizedClientProvider) {
+		Assert.notNull(authorizedClientProvider, "authorizedClientProvider cannot be null");
+		this.authorizedClientProvider = authorizedClientProvider;
+	}
+
+	/**
+	 * Sets the {@code Function} used for mapping attribute(s) from the {@link OAuth2AuthorizeRequest} to a {@code Map} of attributes
+	 * to be associated to the {@link OAuth2AuthorizationContext#getAttributes() authorization context}.
+	 *
+	 * @param contextAttributesMapper the {@code Function} used for supplying the {@code Map} of attributes
+	 *                                   to the {@link OAuth2AuthorizationContext#getAttributes() authorization context}
+	 */
+	public void setContextAttributesMapper(Function<OAuth2AuthorizeRequest, Map<String, Object>> contextAttributesMapper) {
+		Assert.notNull(contextAttributesMapper, "contextAttributesMapper cannot be null");
+		this.contextAttributesMapper = contextAttributesMapper;
+	}
+
+	/**
+	 * The default implementation of the {@link #setContextAttributesMapper(Function) contextAttributesMapper}.
+	 */
+	public static class DefaultContextAttributesMapper implements Function<OAuth2AuthorizeRequest, Map<String, Object>> {
+
+		@Override
+		public Map<String, Object> apply(OAuth2AuthorizeRequest authorizeRequest) {
+			Map<String, Object> contextAttributes = Collections.emptyMap();
+			String scope = authorizeRequest.getAttribute(OAuth2ParameterNames.SCOPE);
+			if (StringUtils.hasText(scope)) {
+				contextAttributes = new HashMap<>();
+				contextAttributes.put(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME,
+						StringUtils.delimitedListToStringArray(scope, " "));
+			}
+			return contextAttributes;
+		}
+	}
+}

+ 280 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizedClientServiceOAuth2AuthorizedClientManagerTests.java

@@ -0,0 +1,280 @@
+/*
+ * Copyright 2002-2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.ArgumentCaptor;
+import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
+import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
+import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
+import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+
+import java.util.function.Function;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.*;
+
+/**
+ * Tests for {@link AuthorizedClientServiceOAuth2AuthorizedClientManager}.
+ *
+ * @author Joe Grandja
+ */
+public class AuthorizedClientServiceOAuth2AuthorizedClientManagerTests {
+	private ClientRegistrationRepository clientRegistrationRepository;
+	private OAuth2AuthorizedClientService authorizedClientService;
+	private OAuth2AuthorizedClientProvider authorizedClientProvider;
+	private Function contextAttributesMapper;
+	private AuthorizedClientServiceOAuth2AuthorizedClientManager authorizedClientManager;
+	private ClientRegistration clientRegistration;
+	private Authentication principal;
+	private OAuth2AuthorizedClient authorizedClient;
+	private ArgumentCaptor<OAuth2AuthorizationContext> authorizationContextCaptor;
+
+	@SuppressWarnings("unchecked")
+	@Before
+	public void setup() {
+		this.clientRegistrationRepository = mock(ClientRegistrationRepository.class);
+		this.authorizedClientService = mock(OAuth2AuthorizedClientService.class);
+		this.authorizedClientProvider = mock(OAuth2AuthorizedClientProvider.class);
+		this.contextAttributesMapper = mock(Function.class);
+		this.authorizedClientManager = new AuthorizedClientServiceOAuth2AuthorizedClientManager(
+				this.clientRegistrationRepository, this.authorizedClientService);
+		this.authorizedClientManager.setAuthorizedClientProvider(this.authorizedClientProvider);
+		this.authorizedClientManager.setContextAttributesMapper(this.contextAttributesMapper);
+		this.clientRegistration = TestClientRegistrations.clientRegistration().build();
+		this.principal = new TestingAuthenticationToken("principal", "password");
+		this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principal.getName(),
+				TestOAuth2AccessTokens.scopes("read", "write"), TestOAuth2RefreshTokens.refreshToken());
+		this.authorizationContextCaptor = ArgumentCaptor.forClass(OAuth2AuthorizationContext.class);
+	}
+
+	@Test
+	public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> new AuthorizedClientServiceOAuth2AuthorizedClientManager(null, this.authorizedClientService))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("clientRegistrationRepository cannot be null");
+	}
+
+	@Test
+	public void constructorWhenOAuth2AuthorizedClientServiceIsNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> new AuthorizedClientServiceOAuth2AuthorizedClientManager(this.clientRegistrationRepository, null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("authorizedClientService cannot be null");
+	}
+
+	@Test
+	public void setAuthorizedClientProviderWhenNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizedClientProvider(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("authorizedClientProvider cannot be null");
+	}
+
+	@Test
+	public void setContextAttributesMapperWhenNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.authorizedClientManager.setContextAttributesMapper(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("contextAttributesMapper cannot be null");
+	}
+
+	@Test
+	public void authorizeWhenRequestIsNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.authorizedClientManager.authorize(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("authorizeRequest cannot be null");
+	}
+
+	@Test
+	public void authorizeWhenClientRegistrationNotFoundThenThrowIllegalArgumentException() {
+		OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId("invalid-registration-id")
+				.principal(this.principal)
+				.build();
+		assertThatThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("Could not find ClientRegistration with id 'invalid-registration-id'");
+	}
+
+	@SuppressWarnings("unchecked")
+	@Test
+	public void authorizeWhenNotAuthorizedAndUnsupportedProviderThenNotAuthorized() {
+		when(this.clientRegistrationRepository.findByRegistrationId(
+				eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration);
+
+		OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
+				.principal(this.principal)
+				.build();
+		OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest);
+
+		verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
+		verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
+
+		OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
+		assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
+		assertThat(authorizationContext.getAuthorizedClient()).isNull();
+		assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
+
+		assertThat(authorizedClient).isNull();
+		verify(this.authorizedClientService, never()).saveAuthorizedClient(
+				any(OAuth2AuthorizedClient.class), eq(this.principal));
+	}
+
+	@SuppressWarnings("unchecked")
+	@Test
+	public void authorizeWhenNotAuthorizedAndSupportedProviderThenAuthorized() {
+		when(this.clientRegistrationRepository.findByRegistrationId(
+				eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration);
+
+		when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(this.authorizedClient);
+
+		OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
+				.principal(this.principal)
+				.build();
+		OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest);
+
+		verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
+		verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
+
+		OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
+		assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
+		assertThat(authorizationContext.getAuthorizedClient()).isNull();
+		assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
+
+		assertThat(authorizedClient).isSameAs(this.authorizedClient);
+		verify(this.authorizedClientService).saveAuthorizedClient(
+				eq(this.authorizedClient), eq(this.principal));
+	}
+
+	@SuppressWarnings("unchecked")
+	@Test
+	public void authorizeWhenAuthorizedAndSupportedProviderThenReauthorized() {
+		when(this.clientRegistrationRepository.findByRegistrationId(
+				eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration);
+		when(this.authorizedClientService.loadAuthorizedClient(
+				eq(this.clientRegistration.getRegistrationId()), eq(this.principal.getName()))).thenReturn(this.authorizedClient);
+
+		OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient(
+				this.clientRegistration, this.principal.getName(),
+				TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken());
+
+		when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(reauthorizedClient);
+
+		OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
+				.principal(this.principal)
+				.build();
+		OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest);
+
+		verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
+		verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
+
+		OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
+		assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
+		assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient);
+		assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
+
+		assertThat(authorizedClient).isSameAs(reauthorizedClient);
+		verify(this.authorizedClientService).saveAuthorizedClient(
+				eq(reauthorizedClient), eq(this.principal));
+	}
+
+	@SuppressWarnings("unchecked")
+	@Test
+	public void reauthorizeWhenUnsupportedProviderThenNotReauthorized() {
+		OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
+				.principal(this.principal)
+				.build();
+		OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest);
+
+		verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
+		verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest));
+
+		OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
+		assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
+		assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient);
+		assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
+
+		assertThat(authorizedClient).isSameAs(this.authorizedClient);
+		verify(this.authorizedClientService, never()).saveAuthorizedClient(
+				any(OAuth2AuthorizedClient.class), eq(this.principal));
+	}
+
+	@SuppressWarnings("unchecked")
+	@Test
+	public void reauthorizeWhenSupportedProviderThenReauthorized() {
+		OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient(
+				this.clientRegistration, this.principal.getName(),
+				TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken());
+
+		when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(reauthorizedClient);
+
+		OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
+				.principal(this.principal)
+				.build();
+		OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest);
+
+		verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
+		verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest));
+
+		OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
+		assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
+		assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient);
+		assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
+
+		assertThat(authorizedClient).isSameAs(reauthorizedClient);
+		verify(this.authorizedClientService).saveAuthorizedClient(
+				eq(reauthorizedClient), eq(this.principal));
+	}
+
+	@SuppressWarnings("unchecked")
+	@Test
+	public void reauthorizeWhenRequestAttributeScopeThenMappedToContext() {
+		OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient(
+				this.clientRegistration, this.principal.getName(),
+				TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken());
+
+		when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(reauthorizedClient);
+
+		// Override the mock with the default
+		this.authorizedClientManager.setContextAttributesMapper(
+				new AuthorizedClientServiceOAuth2AuthorizedClientManager.DefaultContextAttributesMapper());
+
+		OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
+				.principal(this.principal)
+				.attribute(OAuth2ParameterNames.SCOPE, "read write")
+				.build();
+		OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest);
+
+		verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
+
+		OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
+		assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
+		assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient);
+		assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
+		assertThat(authorizationContext.getAttributes()).containsKey(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME);
+		String[] requestScopeAttribute = authorizationContext.getAttribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME);
+		assertThat(requestScopeAttribute).contains("read", "write");
+
+		assertThat(authorizedClient).isSameAs(reauthorizedClient);
+		verify(this.authorizedClientService).saveAuthorizedClient(
+				eq(reauthorizedClient), eq(this.principal));
+	}
+}