|
@@ -1,5 +1,5 @@
|
|
/*
|
|
/*
|
|
- * Copyright 2002-2020 the original author or authors.
|
|
|
|
|
|
+ * Copyright 2002-2021 the original author or authors.
|
|
*
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* you may not use this file except in compliance with the License.
|
|
@@ -16,6 +16,7 @@
|
|
|
|
|
|
package org.springframework.security.oauth2.client.web;
|
|
package org.springframework.security.oauth2.client.web;
|
|
|
|
|
|
|
|
+import java.util.Collection;
|
|
import java.util.HashMap;
|
|
import java.util.HashMap;
|
|
import java.util.Map;
|
|
import java.util.Map;
|
|
|
|
|
|
@@ -33,10 +34,12 @@ import org.springframework.security.authentication.AuthenticationDetailsSource;
|
|
import org.springframework.security.authentication.AuthenticationManager;
|
|
import org.springframework.security.authentication.AuthenticationManager;
|
|
import org.springframework.security.core.Authentication;
|
|
import org.springframework.security.core.Authentication;
|
|
import org.springframework.security.core.AuthenticationException;
|
|
import org.springframework.security.core.AuthenticationException;
|
|
|
|
+import org.springframework.security.core.GrantedAuthority;
|
|
import org.springframework.security.core.authority.AuthorityUtils;
|
|
import org.springframework.security.core.authority.AuthorityUtils;
|
|
import org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService;
|
|
import org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService;
|
|
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
|
|
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
|
|
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
|
|
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
|
|
|
|
+import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
|
|
import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken;
|
|
import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken;
|
|
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
|
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
|
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
|
|
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
|
|
@@ -152,6 +155,12 @@ public class OAuth2LoginAuthenticationFilterTests {
|
|
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthorizationRequestRepository(null));
|
|
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthorizationRequestRepository(null));
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ // gh-10033
|
|
|
|
+ @Test
|
|
|
|
+ public void setAuthenticationResultConverterWhenAuthenticationResultConverterIsNullThenThrowIllegalArgumentException() {
|
|
|
|
+ assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthenticationResultConverter(null));
|
|
|
|
+ }
|
|
|
|
+
|
|
@Test
|
|
@Test
|
|
public void doFilterWhenNotAuthorizationResponseThenNextFilter() throws Exception {
|
|
public void doFilterWhenNotAuthorizationResponseThenNextFilter() throws Exception {
|
|
String requestUri = "/path";
|
|
String requestUri = "/path";
|
|
@@ -416,6 +425,41 @@ public class OAuth2LoginAuthenticationFilterTests {
|
|
assertThat(result.getDetails()).isEqualTo(webAuthenticationDetails);
|
|
assertThat(result.getDetails()).isEqualTo(webAuthenticationDetails);
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ // gh-10033
|
|
|
|
+ @Test
|
|
|
|
+ public void attemptAuthenticationWhenAuthenticationResultIsNullThenIllegalArgumentException() throws Exception {
|
|
|
|
+ this.filter.setAuthenticationResultConverter((authentication) -> null);
|
|
|
|
+ String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId();
|
|
|
|
+ String state = "state";
|
|
|
|
+ MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
|
|
|
+ request.setServletPath(requestUri);
|
|
|
|
+ request.addParameter(OAuth2ParameterNames.CODE, "code");
|
|
|
|
+ request.addParameter(OAuth2ParameterNames.STATE, state);
|
|
|
|
+ MockHttpServletResponse response = new MockHttpServletResponse();
|
|
|
|
+ this.setUpAuthorizationRequest(request, response, this.registration1, state);
|
|
|
|
+ this.setUpAuthenticationResult(this.registration1);
|
|
|
|
+ assertThatIllegalArgumentException().isThrownBy(() -> this.filter.attemptAuthentication(request, response));
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // gh-10033
|
|
|
|
+ @Test
|
|
|
|
+ public void attemptAuthenticationWhenAuthenticationResultConverterSetThenUsed() {
|
|
|
|
+ this.filter.setAuthenticationResultConverter(
|
|
|
|
+ (authentication) -> new CustomOAuth2AuthenticationToken(authentication.getPrincipal(),
|
|
|
|
+ authentication.getAuthorities(), authentication.getClientRegistration().getRegistrationId()));
|
|
|
|
+ String requestUri = "/login/oauth2/code/" + this.registration1.getRegistrationId();
|
|
|
|
+ String state = "state";
|
|
|
|
+ MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
|
|
|
+ request.setServletPath(requestUri);
|
|
|
|
+ request.addParameter(OAuth2ParameterNames.CODE, "code");
|
|
|
|
+ request.addParameter(OAuth2ParameterNames.STATE, state);
|
|
|
|
+ MockHttpServletResponse response = new MockHttpServletResponse();
|
|
|
|
+ this.setUpAuthorizationRequest(request, response, this.registration1, state);
|
|
|
|
+ this.setUpAuthenticationResult(this.registration1);
|
|
|
|
+ Authentication authenticationResult = this.filter.attemptAuthentication(request, response);
|
|
|
|
+ assertThat(authenticationResult).isInstanceOf(CustomOAuth2AuthenticationToken.class);
|
|
|
|
+ }
|
|
|
|
+
|
|
private void setUpAuthorizationRequest(HttpServletRequest request, HttpServletResponse response,
|
|
private void setUpAuthorizationRequest(HttpServletRequest request, HttpServletResponse response,
|
|
ClientRegistration registration, String state) {
|
|
ClientRegistration registration, String state) {
|
|
Map<String, Object> attributes = new HashMap<>();
|
|
Map<String, Object> attributes = new HashMap<>();
|
|
@@ -454,4 +498,13 @@ public class OAuth2LoginAuthenticationFilterTests {
|
|
given(this.authenticationManager.authenticate(any(Authentication.class))).willReturn(this.loginAuthentication);
|
|
given(this.authenticationManager.authenticate(any(Authentication.class))).willReturn(this.loginAuthentication);
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ private static final class CustomOAuth2AuthenticationToken extends OAuth2AuthenticationToken {
|
|
|
|
+
|
|
|
|
+ CustomOAuth2AuthenticationToken(OAuth2User principal, Collection<? extends GrantedAuthority> authorities,
|
|
|
|
+ String authorizedClientRegistrationId) {
|
|
|
|
+ super(principal, authorities, authorizedClientRegistrationId);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ }
|
|
|
|
+
|
|
}
|
|
}
|