2
0
Эх сурвалжийг харах

Move logic from AuthorizationCodeAuthenticationFilter to OAuth2UserAuthenticationProvider

Joe Grandja 7 жил өмнө
parent
commit
df474e04d8

+ 65 - 6
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2UserAuthenticationProvider.java

@@ -20,6 +20,9 @@ import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.core.GrantedAuthority;
 import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
+import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.client.registration.ClientRegistrationIdentifierStrategy;
 import org.springframework.security.oauth2.client.user.DefaultOAuth2UserService;
 import org.springframework.security.oauth2.client.user.OAuth2UserService;
 import org.springframework.security.oauth2.core.user.OAuth2User;
@@ -55,6 +58,7 @@ import java.util.Collection;
  * @see OidcUser
  */
 public class OAuth2UserAuthenticationProvider implements AuthenticationProvider {
+	private final ClientRegistrationIdentifierStrategy<String> providerIdentifierStrategy = new ProviderIdentifierStrategy();
 	private final OAuth2UserService userService;
 	private GrantedAuthoritiesMapper authoritiesMapper = (authorities -> authorities);
 
@@ -65,11 +69,18 @@ public class OAuth2UserAuthenticationProvider implements AuthenticationProvider
 
 	@Override
 	public Authentication authenticate(Authentication authentication) throws AuthenticationException {
-		OAuth2UserAuthenticationToken oauth2UserAuthentication = (OAuth2UserAuthenticationToken) authentication;
+		OAuth2UserAuthenticationToken userAuthentication = (OAuth2UserAuthenticationToken) authentication;
+		OAuth2ClientAuthenticationToken clientAuthentication = userAuthentication.getClientAuthentication();
 
-		OAuth2ClientAuthenticationToken oauth2ClientAuthentication = oauth2UserAuthentication.getClientAuthentication();
+		if (this.userAuthenticated() && this.userAuthenticatedSameProviderAs(clientAuthentication)) {
+			// Create a new user authentication (using same principal)
+			// but with a different client authentication association
+			return this.createUserAuthentication(
+				(OAuth2UserAuthenticationToken)SecurityContextHolder.getContext().getAuthentication(),
+				clientAuthentication);
+		}
 
-		OAuth2User oauth2User = this.userService.loadUser(oauth2ClientAuthentication);
+		OAuth2User oauth2User = this.userService.loadUser(clientAuthentication);
 
 		Collection<? extends GrantedAuthority> mappedAuthorities =
 				this.authoritiesMapper.mapAuthorities(oauth2User.getAuthorities());
@@ -77,12 +88,12 @@ public class OAuth2UserAuthenticationProvider implements AuthenticationProvider
 		OAuth2UserAuthenticationToken authenticationResult;
 		if (OidcUser.class.isAssignableFrom(oauth2User.getClass())) {
 			authenticationResult = new OidcUserAuthenticationToken(
-				(OidcUser)oauth2User, mappedAuthorities, (OidcClientAuthenticationToken)oauth2ClientAuthentication);
+				(OidcUser)oauth2User, mappedAuthorities, (OidcClientAuthenticationToken)clientAuthentication);
 		} else {
 			authenticationResult = new OAuth2UserAuthenticationToken(
-				oauth2User, mappedAuthorities, oauth2ClientAuthentication);
+				oauth2User, mappedAuthorities, clientAuthentication);
 		}
-		authenticationResult.setDetails(oauth2ClientAuthentication.getDetails());
+		authenticationResult.setDetails(clientAuthentication.getDetails());
 
 		return authenticationResult;
 	}
@@ -96,4 +107,52 @@ public class OAuth2UserAuthenticationProvider implements AuthenticationProvider
 	public boolean supports(Class<?> authentication) {
 		return OAuth2UserAuthenticationToken.class.isAssignableFrom(authentication);
 	}
+
+	private boolean userAuthenticated() {
+		Authentication currentAuthentication = SecurityContextHolder.getContext().getAuthentication();
+		return currentAuthentication != null &&
+			currentAuthentication instanceof OAuth2UserAuthenticationToken &&
+			currentAuthentication.isAuthenticated();
+	}
+
+	private boolean userAuthenticatedSameProviderAs(OAuth2ClientAuthenticationToken clientAuthentication) {
+		OAuth2UserAuthenticationToken currentUserAuthentication =
+			(OAuth2UserAuthenticationToken)SecurityContextHolder.getContext().getAuthentication();
+
+		String userProviderId = this.providerIdentifierStrategy.getIdentifier(
+			currentUserAuthentication.getClientAuthentication().getClientRegistration());
+		String clientProviderId = this.providerIdentifierStrategy.getIdentifier(
+			clientAuthentication.getClientRegistration());
+
+		return userProviderId.equals(clientProviderId);
+	}
+
+	private OAuth2UserAuthenticationToken createUserAuthentication(
+		OAuth2UserAuthenticationToken currentUserAuthentication,
+		OAuth2ClientAuthenticationToken newClientAuthentication) {
+
+		if (OidcUserAuthenticationToken.class.isAssignableFrom(currentUserAuthentication.getClass())) {
+			return new OidcUserAuthenticationToken(
+				(OidcUser) currentUserAuthentication.getPrincipal(),
+				currentUserAuthentication.getAuthorities(),
+				newClientAuthentication);
+		} else {
+			return new OAuth2UserAuthenticationToken(
+				(OAuth2User)currentUserAuthentication.getPrincipal(),
+				currentUserAuthentication.getAuthorities(),
+				newClientAuthentication);
+		}
+	}
+
+	private static class ProviderIdentifierStrategy implements ClientRegistrationIdentifierStrategy<String> {
+
+		@Override
+		public String getIdentifier(ClientRegistration clientRegistration) {
+			StringBuilder builder = new StringBuilder();
+			builder.append("[").append(clientRegistration.getProviderDetails().getAuthorizationUri()).append("]");
+			builder.append("[").append(clientRegistration.getProviderDetails().getTokenUri()).append("]");
+			builder.append("[").append(clientRegistration.getProviderDetails().getUserInfoEndpoint().getUri()).append("]");
+			return builder.toString();
+		}
+	}
 }

+ 2 - 77
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/AuthorizationCodeAuthenticationFilter.java

@@ -18,23 +18,17 @@ package org.springframework.security.oauth2.client.web;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
-import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.oauth2.client.authentication.AuthorizationCodeAuthenticationProvider;
 import org.springframework.security.oauth2.client.authentication.AuthorizationCodeAuthenticationToken;
 import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.client.authentication.OAuth2ClientAuthenticationToken;
 import org.springframework.security.oauth2.client.authentication.OAuth2UserAuthenticationToken;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
-import org.springframework.security.oauth2.client.registration.ClientRegistrationIdentifierStrategy;
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.endpoint.AuthorizationRequest;
 import org.springframework.security.oauth2.core.endpoint.AuthorizationResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2Parameter;
-import org.springframework.security.oauth2.core.user.OAuth2User;
-import org.springframework.security.oauth2.oidc.client.authentication.OidcClientAuthenticationToken;
-import org.springframework.security.oauth2.oidc.client.authentication.OidcUserAuthenticationToken;
-import org.springframework.security.oauth2.oidc.core.user.OidcUser;
 import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.util.Assert;
@@ -81,7 +75,6 @@ import java.io.IOException;
 public class AuthorizationCodeAuthenticationFilter extends AbstractAuthenticationProcessingFilter {
 	public static final String DEFAULT_AUTHORIZATION_RESPONSE_BASE_URI = "/oauth2/authorize/code";
 	private static final String AUTHORIZATION_REQUEST_NOT_FOUND_ERROR_CODE = "authorization_request_not_found";
-	private final ClientRegistrationIdentifierStrategy<String> providerIdentifierStrategy = new ProviderIdentifierStrategy();
 	private AuthorizationResponseMatcher authorizationResponseMatcher;
 	private ClientRegistrationRepository clientRegistrationRepository;
 	private AuthorizationRequestRepository authorizationRequestRepository = new HttpSessionAuthorizationRequestRepository();
@@ -135,20 +128,8 @@ public class AuthorizationCodeAuthenticationFilter extends AbstractAuthenticatio
 		OAuth2ClientAuthenticationToken oauth2ClientAuthentication =
 			(OAuth2ClientAuthenticationToken)this.getAuthenticationManager().authenticate(authorizationCodeAuthentication);
 
-		OAuth2UserAuthenticationToken oauth2UserAuthentication;
-		if (this.authenticated() && this.authenticatedSameProviderAs(oauth2ClientAuthentication)) {
-			// Create a new user authentication (using same principal)
-			// but with a different client authentication association
-			oauth2UserAuthentication = (OAuth2UserAuthenticationToken)SecurityContextHolder.getContext().getAuthentication();
-			oauth2UserAuthentication = this.createUserAuthentication(oauth2UserAuthentication, oauth2ClientAuthentication);
-		} else {
-			// Authenticate the user... the user needs to be authenticated
-			// before we can associate the client authentication to the user
-			oauth2UserAuthentication = (OAuth2UserAuthenticationToken)this.getAuthenticationManager().authenticate(
-				this.createUserAuthentication(oauth2ClientAuthentication));
-		}
-
-		return oauth2UserAuthentication;
+		return this.getAuthenticationManager().authenticate(
+			new OAuth2UserAuthenticationToken(oauth2ClientAuthentication));
 	}
 
 	public final RequestMatcher getAuthorizationResponseMatcher() {
@@ -171,50 +152,6 @@ public class AuthorizationCodeAuthenticationFilter extends AbstractAuthenticatio
 		this.authorizationRequestRepository = authorizationRequestRepository;
 	}
 
-	private boolean authenticated() {
-		Authentication currentAuthentication = SecurityContextHolder.getContext().getAuthentication();
-		return currentAuthentication != null &&
-			currentAuthentication instanceof OAuth2UserAuthenticationToken &&
-			currentAuthentication.isAuthenticated();
-	}
-
-	private boolean authenticatedSameProviderAs(OAuth2ClientAuthenticationToken oauth2ClientAuthentication) {
-		OAuth2UserAuthenticationToken userAuthentication =
-			(OAuth2UserAuthenticationToken)SecurityContextHolder.getContext().getAuthentication();
-
-		String userProviderId = this.providerIdentifierStrategy.getIdentifier(
-			userAuthentication.getClientAuthentication().getClientRegistration());
-		String clientProviderId = this.providerIdentifierStrategy.getIdentifier(
-			oauth2ClientAuthentication.getClientRegistration());
-
-		return userProviderId.equals(clientProviderId);
-	}
-
-	private OAuth2UserAuthenticationToken createUserAuthentication(OAuth2ClientAuthenticationToken clientAuthentication) {
-		if (OidcClientAuthenticationToken.class.isAssignableFrom(clientAuthentication.getClass())) {
-			return new OidcUserAuthenticationToken((OidcClientAuthenticationToken)clientAuthentication);
-		} else {
-			return new OAuth2UserAuthenticationToken(clientAuthentication);
-		}
-	}
-
-	private OAuth2UserAuthenticationToken createUserAuthentication(
-		OAuth2UserAuthenticationToken currentUserAuthentication,
-		OAuth2ClientAuthenticationToken newClientAuthentication) {
-
-		if (OidcUserAuthenticationToken.class.isAssignableFrom(currentUserAuthentication.getClass())) {
-			return new OidcUserAuthenticationToken(
-				(OidcUser) currentUserAuthentication.getPrincipal(),
-				currentUserAuthentication.getAuthorities(),
-				newClientAuthentication);
-		} else {
-			return new OAuth2UserAuthenticationToken(
-				(OAuth2User)currentUserAuthentication.getPrincipal(),
-				currentUserAuthentication.getAuthorities(),
-				newClientAuthentication);
-		}
-	}
-
 	private static class AuthorizationResponseMatcher implements RequestMatcher {
 		private final String baseUri;
 
@@ -266,16 +203,4 @@ public class AuthorizationCodeAuthenticationFilter extends AbstractAuthenticatio
 			}
 		}
 	}
-
-	private static class ProviderIdentifierStrategy implements ClientRegistrationIdentifierStrategy<String> {
-
-		@Override
-		public String getIdentifier(ClientRegistration clientRegistration) {
-			StringBuilder builder = new StringBuilder();
-			builder.append("[").append(clientRegistration.getProviderDetails().getAuthorizationUri()).append("]");
-			builder.append("[").append(clientRegistration.getProviderDetails().getTokenUri()).append("]");
-			builder.append("[").append(clientRegistration.getProviderDetails().getUserInfoEndpoint().getUri()).append("]");
-			return builder.toString();
-		}
-	}
 }

+ 1 - 1
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/AuthorizationCodeAuthenticationFilterTests.java

@@ -128,7 +128,7 @@ public class AuthorizationCodeAuthenticationFilterTests {
 		ArgumentCaptor<Authentication> authenticationArgCaptor = ArgumentCaptor.forClass(Authentication.class);
 		Mockito.verify(successHandler).onAuthenticationSuccess(Matchers.any(HttpServletRequest.class), Matchers.any(HttpServletResponse.class),
 				authenticationArgCaptor.capture());
-		Assertions.assertThat(authenticationArgCaptor.getValue()).isEqualTo(userAuthentication);
+		Assertions.assertThat(authenticationArgCaptor.getValue()).isEqualTo(clientAuthentication);
 	}
 
 	@Test