|
@@ -19,9 +19,13 @@ import java.time.Instant;
|
|
|
import java.util.Arrays;
|
|
|
import java.util.Collections;
|
|
|
import java.util.HashSet;
|
|
|
+import java.util.List;
|
|
|
import java.util.Set;
|
|
|
+import java.util.function.Consumer;
|
|
|
import java.util.function.Function;
|
|
|
|
|
|
+import javax.servlet.http.HttpServletResponse;
|
|
|
+
|
|
|
import com.nimbusds.jose.jwk.JWKSet;
|
|
|
import com.nimbusds.jose.jwk.source.ImmutableJWKSet;
|
|
|
import com.nimbusds.jose.jwk.source.JWKSource;
|
|
@@ -30,10 +34,13 @@ import org.junit.Before;
|
|
|
import org.junit.BeforeClass;
|
|
|
import org.junit.Rule;
|
|
|
import org.junit.Test;
|
|
|
+import org.mockito.ArgumentCaptor;
|
|
|
|
|
|
import org.springframework.beans.factory.annotation.Autowired;
|
|
|
import org.springframework.context.annotation.Bean;
|
|
|
import org.springframework.http.HttpHeaders;
|
|
|
+import org.springframework.http.HttpStatus;
|
|
|
+import org.springframework.security.authentication.AuthenticationProvider;
|
|
|
import org.springframework.security.config.Customizer;
|
|
|
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
|
|
|
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
|
|
@@ -61,11 +68,15 @@ import org.springframework.security.oauth2.server.authorization.client.Registere
|
|
|
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
|
|
|
import org.springframework.security.oauth2.server.authorization.config.annotation.web.configuration.OAuth2AuthorizationServerConfiguration;
|
|
|
import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationContext;
|
|
|
+import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationProvider;
|
|
|
import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationToken;
|
|
|
import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings;
|
|
|
import org.springframework.security.oauth2.server.authorization.test.SpringTestRule;
|
|
|
import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken;
|
|
|
import org.springframework.security.web.SecurityFilterChain;
|
|
|
+import org.springframework.security.web.authentication.AuthenticationConverter;
|
|
|
+import org.springframework.security.web.authentication.AuthenticationFailureHandler;
|
|
|
+import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
|
|
|
import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
|
|
|
import org.springframework.security.web.context.SecurityContextRepository;
|
|
|
import org.springframework.security.web.util.matcher.RequestMatcher;
|
|
@@ -74,8 +85,15 @@ import org.springframework.test.web.servlet.MvcResult;
|
|
|
import org.springframework.test.web.servlet.ResultMatcher;
|
|
|
|
|
|
import static org.assertj.core.api.Assertions.assertThat;
|
|
|
+import static org.mockito.ArgumentMatchers.any;
|
|
|
+import static org.mockito.ArgumentMatchers.eq;
|
|
|
+import static org.mockito.Mockito.doAnswer;
|
|
|
+import static org.mockito.Mockito.mock;
|
|
|
import static org.mockito.Mockito.reset;
|
|
|
import static org.mockito.Mockito.spy;
|
|
|
+import static org.mockito.Mockito.verify;
|
|
|
+import static org.mockito.Mockito.verifyNoInteractions;
|
|
|
+import static org.mockito.Mockito.when;
|
|
|
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
|
|
|
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
|
|
|
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath;
|
|
@@ -99,17 +117,48 @@ public class OidcUserInfoTests {
|
|
|
@Autowired
|
|
|
private JwtEncoder jwtEncoder;
|
|
|
|
|
|
+ @Autowired
|
|
|
+ private JwtDecoder jwtDecoder;
|
|
|
+
|
|
|
@Autowired
|
|
|
private OAuth2AuthorizationService authorizationService;
|
|
|
|
|
|
+ private static AuthenticationConverter authenticationConverter;
|
|
|
+
|
|
|
+ private static Consumer<List<AuthenticationConverter>> authenticationConvertersConsumer;
|
|
|
+
|
|
|
+ private static AuthenticationProvider authenticationProvider;
|
|
|
+
|
|
|
+ private static Consumer<List<AuthenticationProvider>> authenticationProvidersConsumer;
|
|
|
+
|
|
|
+ private static AuthenticationSuccessHandler authenticationSuccessHandler;
|
|
|
+
|
|
|
+ private static AuthenticationFailureHandler authenticationFailureHandler;
|
|
|
+
|
|
|
+ private static Function<OidcUserInfoAuthenticationContext, OidcUserInfo> userInfoMapper;
|
|
|
+
|
|
|
@BeforeClass
|
|
|
public static void init() {
|
|
|
securityContextRepository = spy(new HttpSessionSecurityContextRepository());
|
|
|
+ authenticationConverter = mock(AuthenticationConverter.class);
|
|
|
+ authenticationConvertersConsumer = mock(Consumer.class);
|
|
|
+ authenticationProvider = mock(AuthenticationProvider.class);
|
|
|
+ authenticationProvidersConsumer = mock(Consumer.class);
|
|
|
+ authenticationSuccessHandler = mock(AuthenticationSuccessHandler.class);
|
|
|
+ authenticationFailureHandler = mock(AuthenticationFailureHandler.class);
|
|
|
+ userInfoMapper = mock(Function.class);
|
|
|
}
|
|
|
|
|
|
@Before
|
|
|
public void setup() {
|
|
|
reset(securityContextRepository);
|
|
|
+ reset(authenticationConverter);
|
|
|
+ reset(authenticationConvertersConsumer);
|
|
|
+ reset(authenticationProvider);
|
|
|
+ reset(authenticationProvidersConsumer);
|
|
|
+ reset(authenticationSuccessHandler);
|
|
|
+ reset(authenticationFailureHandler);
|
|
|
+ reset(userInfoMapper);
|
|
|
}
|
|
|
|
|
|
@Test
|
|
@@ -145,19 +194,89 @@ public class OidcUserInfoTests {
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
- public void requestWhenSignedJwtAndCustomUserInfoMapperThenMapJwtClaimsToUserInfoResponse() throws Exception {
|
|
|
+ public void requestWhenUserInfoEndpointCustomizedThenUsed() throws Exception {
|
|
|
this.spring.register(CustomUserInfoConfiguration.class).autowire();
|
|
|
|
|
|
OAuth2Authorization authorization = createAuthorization();
|
|
|
this.authorizationService.save(authorization);
|
|
|
|
|
|
+ when(userInfoMapper.apply(any())).thenReturn(createUserInfo());
|
|
|
+
|
|
|
OAuth2AccessToken accessToken = authorization.getAccessToken().getToken();
|
|
|
// @formatter:off
|
|
|
this.mvc.perform(get(DEFAULT_OIDC_USER_INFO_ENDPOINT_URI)
|
|
|
.header(HttpHeaders.AUTHORIZATION, "Bearer " + accessToken.getTokenValue()))
|
|
|
- .andExpect(status().is2xxSuccessful())
|
|
|
- .andExpectAll(userInfoResponse());
|
|
|
+ .andExpect(status().is2xxSuccessful());
|
|
|
// @formatter:on
|
|
|
+ verify(userInfoMapper).apply(any());
|
|
|
+ verify(authenticationConverter).convert(any());
|
|
|
+ verify(authenticationSuccessHandler).onAuthenticationSuccess(any(), any(), any());
|
|
|
+ verifyNoInteractions(authenticationFailureHandler);
|
|
|
+
|
|
|
+ ArgumentCaptor<List<AuthenticationProvider>> authenticationProvidersCaptor = ArgumentCaptor.forClass(List.class);
|
|
|
+ verify(authenticationProvidersConsumer).accept(authenticationProvidersCaptor.capture());
|
|
|
+ List<AuthenticationProvider> authenticationProviders = authenticationProvidersCaptor.getValue();
|
|
|
+ assertThat(authenticationProviders).hasSize(2).allMatch(provider ->
|
|
|
+ provider == authenticationProvider ||
|
|
|
+ provider instanceof OidcUserInfoAuthenticationProvider
|
|
|
+ );
|
|
|
+
|
|
|
+ ArgumentCaptor<List<AuthenticationConverter>> authenticationConvertersCaptor = ArgumentCaptor.forClass(List.class);
|
|
|
+ verify(authenticationConvertersConsumer).accept(authenticationConvertersCaptor.capture());
|
|
|
+ List<AuthenticationConverter> authenticationConverters = authenticationConvertersCaptor.getValue();
|
|
|
+ assertThat(authenticationConverters).hasSize(2).allMatch(AuthenticationConverter.class::isInstance);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ public void requestWhenUserInfoEndpointCustomizedThenAuthenticationProviderUsed() throws Exception {
|
|
|
+ this.spring.register(CustomUserInfoConfiguration.class).autowire();
|
|
|
+
|
|
|
+ OAuth2Authorization authorization = createAuthorization();
|
|
|
+ this.authorizationService.save(authorization);
|
|
|
+
|
|
|
+ when(authenticationProvider.supports(eq(OidcUserInfoAuthenticationToken.class))).thenReturn(true);
|
|
|
+ String tokenValue = authorization.getAccessToken().getToken().getTokenValue();
|
|
|
+ Jwt jwt = this.jwtDecoder.decode(tokenValue);
|
|
|
+ OidcUserInfoAuthenticationToken oidcUserInfoAuthentication = new OidcUserInfoAuthenticationToken(
|
|
|
+ new JwtAuthenticationToken(jwt), createUserInfo());
|
|
|
+ when(authenticationProvider.authenticate(any())).thenReturn(oidcUserInfoAuthentication);
|
|
|
+
|
|
|
+ OAuth2AccessToken accessToken = authorization.getAccessToken().getToken();
|
|
|
+ // @formatter:off
|
|
|
+ this.mvc.perform(get(DEFAULT_OIDC_USER_INFO_ENDPOINT_URI)
|
|
|
+ .header(HttpHeaders.AUTHORIZATION, "Bearer " + accessToken.getTokenValue()))
|
|
|
+ .andExpect(status().is2xxSuccessful());
|
|
|
+ // @formatter:on
|
|
|
+ verify(authenticationSuccessHandler).onAuthenticationSuccess(any(), any(), any());
|
|
|
+ verify(authenticationProvider).authenticate(any());
|
|
|
+ verifyNoInteractions(authenticationFailureHandler);
|
|
|
+ verifyNoInteractions(userInfoMapper);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ public void requestWhenUserInfoEndpointCustomizedAndErrorThenUsed() throws Exception {
|
|
|
+ this.spring.register(CustomUserInfoConfiguration.class).autowire();
|
|
|
+ when(userInfoMapper.apply(any())).thenReturn(createUserInfo());
|
|
|
+ doAnswer(
|
|
|
+ invocation -> {
|
|
|
+ HttpServletResponse response = invocation.getArgument(1);
|
|
|
+ response.setStatus(HttpStatus.UNAUTHORIZED.value());
|
|
|
+ response.getWriter().write("unauthorized");
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+ ).when(authenticationFailureHandler).onAuthenticationFailure(any(), any(), any());
|
|
|
+
|
|
|
+ OAuth2AccessToken accessToken = createAuthorization().getAccessToken().getToken();
|
|
|
+
|
|
|
+
|
|
|
+ // @formatter:off
|
|
|
+ this.mvc.perform(get(DEFAULT_OIDC_USER_INFO_ENDPOINT_URI)
|
|
|
+ .header(HttpHeaders.AUTHORIZATION, "Bearer " + accessToken.getTokenValue()))
|
|
|
+ .andExpect(status().is4xxClientError());
|
|
|
+ // @formatter:on
|
|
|
+ verify(authenticationFailureHandler).onAuthenticationFailure(any(), any(), any());
|
|
|
+ verifyNoInteractions(authenticationSuccessHandler);
|
|
|
+ verifyNoInteractions(userInfoMapper);
|
|
|
}
|
|
|
|
|
|
// gh-482
|
|
@@ -271,14 +390,6 @@ public class OidcUserInfoTests {
|
|
|
RequestMatcher endpointsMatcher = authorizationServerConfigurer
|
|
|
.getEndpointsMatcher();
|
|
|
|
|
|
- // Custom User Info Mapper that retrieves claims from a signed JWT
|
|
|
- Function<OidcUserInfoAuthenticationContext, OidcUserInfo> userInfoMapper = context -> {
|
|
|
- OidcUserInfoAuthenticationToken authentication = context.getAuthentication();
|
|
|
- JwtAuthenticationToken principal = (JwtAuthenticationToken) authentication.getPrincipal();
|
|
|
-
|
|
|
- return new OidcUserInfo(principal.getToken().getClaims());
|
|
|
- };
|
|
|
-
|
|
|
// @formatter:off
|
|
|
http
|
|
|
.requestMatcher(endpointsMatcher)
|
|
@@ -290,6 +401,12 @@ public class OidcUserInfoTests {
|
|
|
.apply(authorizationServerConfigurer)
|
|
|
.oidc(oidc -> oidc
|
|
|
.userInfoEndpoint(userInfo -> userInfo
|
|
|
+ .userInfoRequestConverter(authenticationConverter)
|
|
|
+ .userInfoRequestConverters(authenticationConvertersConsumer)
|
|
|
+ .authenticationProvider(authenticationProvider)
|
|
|
+ .authenticationProviders(authenticationProvidersConsumer)
|
|
|
+ .userInfoResponseHandler(authenticationSuccessHandler)
|
|
|
+ .errorResponseHandler(authenticationFailureHandler)
|
|
|
.userInfoMapper(userInfoMapper)
|
|
|
)
|
|
|
);
|