Pārlūkot izejas kodu

Add converter for authentication result in OAuth2LoginAuthenticationFilter

Closes gh-10033
Steve Riesenberg 4 gadi atpakaļ
vecāks
revīzija
6d6dc113d8

+ 25 - 4
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilter.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -19,6 +19,7 @@ package org.springframework.security.oauth2.client.web;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
 
+import org.springframework.core.convert.converter.Converter;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
@@ -111,6 +112,8 @@ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProce
 
 	private AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = new HttpSessionOAuth2AuthorizationRequestRepository();
 
+	private Converter<OAuth2LoginAuthenticationToken, OAuth2AuthenticationToken> authenticationResultConverter = this::createAuthenticationResult;
+
 	/**
 	 * Constructs an {@code OAuth2LoginAuthenticationFilter} using the provided
 	 * parameters.
@@ -190,9 +193,9 @@ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProce
 		authenticationRequest.setDetails(authenticationDetails);
 		OAuth2LoginAuthenticationToken authenticationResult = (OAuth2LoginAuthenticationToken) this
 				.getAuthenticationManager().authenticate(authenticationRequest);
-		OAuth2AuthenticationToken oauth2Authentication = new OAuth2AuthenticationToken(
-				authenticationResult.getPrincipal(), authenticationResult.getAuthorities(),
-				authenticationResult.getClientRegistration().getRegistrationId());
+		OAuth2AuthenticationToken oauth2Authentication = this.authenticationResultConverter
+				.convert(authenticationResult);
+		Assert.notNull(oauth2Authentication, "authentication result cannot be null");
 		oauth2Authentication.setDetails(authenticationDetails);
 		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
 				authenticationResult.getClientRegistration(), oauth2Authentication.getName(),
@@ -213,4 +216,22 @@ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProce
 		this.authorizationRequestRepository = authorizationRequestRepository;
 	}
 
+	/**
+	 * Sets the converter responsible for converting from
+	 * {@link OAuth2LoginAuthenticationToken} to {@link OAuth2AuthenticationToken}
+	 * authentication result.
+	 * @param authenticationResultConverter the converter for
+	 * {@link OAuth2AuthenticationToken}'s
+	 */
+	public final void setAuthenticationResultConverter(
+			Converter<OAuth2LoginAuthenticationToken, OAuth2AuthenticationToken> authenticationResultConverter) {
+		Assert.notNull(authenticationResultConverter, "authenticationResultConverter cannot be null");
+		this.authenticationResultConverter = authenticationResultConverter;
+	}
+
+	private OAuth2AuthenticationToken createAuthenticationResult(OAuth2LoginAuthenticationToken authenticationResult) {
+		return new OAuth2AuthenticationToken(authenticationResult.getPrincipal(), authenticationResult.getAuthorities(),
+				authenticationResult.getClientRegistration().getRegistrationId());
+	}
+
 }

+ 54 - 1
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2LoginAuthenticationFilterTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@
 
 package org.springframework.security.oauth2.client.web;
 
+import java.util.Collection;
 import java.util.HashMap;
 import java.util.Map;
 
@@ -33,10 +34,12 @@ import org.springframework.security.authentication.AuthenticationDetailsSource;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
+import org.springframework.security.core.GrantedAuthority;
 import org.springframework.security.core.authority.AuthorityUtils;
 import org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
+import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
 import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
@@ -152,6 +155,12 @@ public class OAuth2LoginAuthenticationFilterTests {
 		assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthorizationRequestRepository(null));
 	}
 
+	// gh-10033
+	@Test
+	public void setAuthenticationResultConverterWhenAuthenticationResultConverterIsNullThenThrowIllegalArgumentException() {
+		assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthenticationResultConverter(null));
+	}
+
 	@Test
 	public void doFilterWhenNotAuthorizationResponseThenNextFilter() throws Exception {
 		String requestUri = "/path";
@@ -416,6 +425,41 @@ public class OAuth2LoginAuthenticationFilterTests {
 		assertThat(result.getDetails()).isEqualTo(webAuthenticationDetails);
 	}
 
+	// gh-10033
+	@Test
+	public void attemptAuthenticationWhenAuthenticationResultIsNullThenIllegalArgumentException() throws Exception {
+		this.filter.setAuthenticationResultConverter((authentication) -> null);
+		String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId();
+		String state = "state";
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
+		request.addParameter(OAuth2ParameterNames.CODE, "code");
+		request.addParameter(OAuth2ParameterNames.STATE, state);
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		this.setUpAuthorizationRequest(request, response, this.registration1, state);
+		this.setUpAuthenticationResult(this.registration1);
+		assertThatIllegalArgumentException().isThrownBy(() -> this.filter.attemptAuthentication(request, response));
+	}
+
+	// gh-10033
+	@Test
+	public void attemptAuthenticationWhenAuthenticationResultConverterSetThenUsed() {
+		this.filter.setAuthenticationResultConverter(
+				(authentication) -> new CustomOAuth2AuthenticationToken(authentication.getPrincipal(),
+						authentication.getAuthorities(), authentication.getClientRegistration().getRegistrationId()));
+		String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId();
+		String state = "state";
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
+		request.addParameter(OAuth2ParameterNames.CODE, "code");
+		request.addParameter(OAuth2ParameterNames.STATE, state);
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		this.setUpAuthorizationRequest(request, response, this.registration1, state);
+		this.setUpAuthenticationResult(this.registration1);
+		Authentication authenticationResult = this.filter.attemptAuthentication(request, response);
+		assertThat(authenticationResult).isInstanceOf(CustomOAuth2AuthenticationToken.class);
+	}
+
 	private void setUpAuthorizationRequest(HttpServletRequest request, HttpServletResponse response,
 			ClientRegistration registration, String state) {
 		Map<String, Object> attributes = new HashMap<>();
@@ -454,4 +498,13 @@ public class OAuth2LoginAuthenticationFilterTests {
 		given(this.authenticationManager.authenticate(any(Authentication.class))).willReturn(this.loginAuthentication);
 	}
 
+	private static final class CustomOAuth2AuthenticationToken extends OAuth2AuthenticationToken {
+
+		CustomOAuth2AuthenticationToken(OAuth2User principal, Collection<? extends GrantedAuthority> authorities,
+				String authorizedClientRegistrationId) {
+			super(principal, authorities, authorizedClientRegistrationId);
+		}
+
+	}
+
 }