|
@@ -143,12 +143,12 @@ public class OAuth2LoginApplicationTests {
|
|
Map<String, String> params = uriComponents.getQueryParams().toSingleValueMap();
|
|
Map<String, String> params = uriComponents.getQueryParams().toSingleValueMap();
|
|
|
|
|
|
assertThat(params.get(OAuth2ParameterNames.RESPONSE_TYPE))
|
|
assertThat(params.get(OAuth2ParameterNames.RESPONSE_TYPE))
|
|
- .isEqualTo(OAuth2AuthorizationResponseType.CODE.getValue());
|
|
|
|
|
|
+ .isEqualTo(OAuth2AuthorizationResponseType.CODE.getValue());
|
|
assertThat(params.get(OAuth2ParameterNames.CLIENT_ID)).isEqualTo(clientRegistration.getClientId());
|
|
assertThat(params.get(OAuth2ParameterNames.CLIENT_ID)).isEqualTo(clientRegistration.getClientId());
|
|
String redirectUri = AUTHORIZE_BASE_URL + "/" + clientRegistration.getRegistrationId();
|
|
String redirectUri = AUTHORIZE_BASE_URL + "/" + clientRegistration.getRegistrationId();
|
|
assertThat(URLDecoder.decode(params.get(OAuth2ParameterNames.REDIRECT_URI), "UTF-8")).isEqualTo(redirectUri);
|
|
assertThat(URLDecoder.decode(params.get(OAuth2ParameterNames.REDIRECT_URI), "UTF-8")).isEqualTo(redirectUri);
|
|
assertThat(URLDecoder.decode(params.get(OAuth2ParameterNames.SCOPE), "UTF-8"))
|
|
assertThat(URLDecoder.decode(params.get(OAuth2ParameterNames.SCOPE), "UTF-8"))
|
|
- .isEqualTo(clientRegistration.getScopes().stream().collect(Collectors.joining(" ")));
|
|
|
|
|
|
+ .isEqualTo(clientRegistration.getScopes().stream().collect(Collectors.joining(" ")));
|
|
assertThat(params.get(OAuth2ParameterNames.STATE)).isNotNull();
|
|
assertThat(params.get(OAuth2ParameterNames.STATE)).isNotNull();
|
|
}
|
|
}
|
|
|
|
|
|
@@ -185,7 +185,8 @@ public class OAuth2LoginApplicationTests {
|
|
WebResponse response = this.followLinkDisableRedirects(clientAnchorElement);
|
|
WebResponse response = this.followLinkDisableRedirects(clientAnchorElement);
|
|
|
|
|
|
UriComponents authorizeRequestUriComponents = UriComponentsBuilder
|
|
UriComponents authorizeRequestUriComponents = UriComponentsBuilder
|
|
- .fromUri(URI.create(response.getResponseHeaderValue("Location"))).build();
|
|
|
|
|
|
+ .fromUri(URI.create(response.getResponseHeaderValue("Location")))
|
|
|
|
+ .build();
|
|
|
|
|
|
Map<String, String> params = authorizeRequestUriComponents.getQueryParams().toSingleValueMap();
|
|
Map<String, String> params = authorizeRequestUriComponents.getQueryParams().toSingleValueMap();
|
|
String code = "auth-code";
|
|
String code = "auth-code";
|
|
@@ -193,8 +194,11 @@ public class OAuth2LoginApplicationTests {
|
|
String redirectUri = URLDecoder.decode(params.get(OAuth2ParameterNames.REDIRECT_URI), "UTF-8");
|
|
String redirectUri = URLDecoder.decode(params.get(OAuth2ParameterNames.REDIRECT_URI), "UTF-8");
|
|
|
|
|
|
String authorizationResponseUri = UriComponentsBuilder.fromHttpUrl(redirectUri)
|
|
String authorizationResponseUri = UriComponentsBuilder.fromHttpUrl(redirectUri)
|
|
- .queryParam(OAuth2ParameterNames.CODE, code).queryParam(OAuth2ParameterNames.STATE, state).build()
|
|
|
|
- .encode().toUriString();
|
|
|
|
|
|
+ .queryParam(OAuth2ParameterNames.CODE, code)
|
|
|
|
+ .queryParam(OAuth2ParameterNames.STATE, state)
|
|
|
|
+ .build()
|
|
|
|
+ .encode()
|
|
|
|
+ .toUriString();
|
|
|
|
|
|
page = this.webClient.getPage(new URL(authorizationResponseUri));
|
|
page = this.webClient.getPage(new URL(authorizationResponseUri));
|
|
this.assertIndexPage(page);
|
|
this.assertIndexPage(page);
|
|
@@ -214,8 +218,11 @@ public class OAuth2LoginApplicationTests {
|
|
String redirectUri = AUTHORIZE_BASE_URL + "/" + clientRegistration.getRegistrationId();
|
|
String redirectUri = AUTHORIZE_BASE_URL + "/" + clientRegistration.getRegistrationId();
|
|
|
|
|
|
String authorizationResponseUri = UriComponentsBuilder.fromHttpUrl(redirectUri)
|
|
String authorizationResponseUri = UriComponentsBuilder.fromHttpUrl(redirectUri)
|
|
- .queryParam(OAuth2ParameterNames.CODE, code).queryParam(OAuth2ParameterNames.STATE, state).build()
|
|
|
|
- .encode().toUriString();
|
|
|
|
|
|
+ .queryParam(OAuth2ParameterNames.CODE, code)
|
|
|
|
+ .queryParam(OAuth2ParameterNames.STATE, state)
|
|
|
|
+ .build()
|
|
|
|
+ .encode()
|
|
|
|
+ .toUriString();
|
|
|
|
|
|
// Clear session cookie will ensure the 'session-saved'
|
|
// Clear session cookie will ensure the 'session-saved'
|
|
// Authorization Request (from previous request) is not found
|
|
// Authorization Request (from previous request) is not found
|
|
@@ -246,8 +253,11 @@ public class OAuth2LoginApplicationTests {
|
|
String redirectUri = AUTHORIZE_BASE_URL + "/" + clientRegistration.getRegistrationId();
|
|
String redirectUri = AUTHORIZE_BASE_URL + "/" + clientRegistration.getRegistrationId();
|
|
|
|
|
|
String authorizationResponseUri = UriComponentsBuilder.fromHttpUrl(redirectUri)
|
|
String authorizationResponseUri = UriComponentsBuilder.fromHttpUrl(redirectUri)
|
|
- .queryParam(OAuth2ParameterNames.CODE, code).queryParam(OAuth2ParameterNames.STATE, state).build()
|
|
|
|
- .encode().toUriString();
|
|
|
|
|
|
+ .queryParam(OAuth2ParameterNames.CODE, code)
|
|
|
|
+ .queryParam(OAuth2ParameterNames.STATE, state)
|
|
|
|
+ .build()
|
|
|
|
+ .encode()
|
|
|
|
+ .toUriString();
|
|
|
|
|
|
page = this.webClient.getPage(new URL(authorizationResponseUri));
|
|
page = this.webClient.getPage(new URL(authorizationResponseUri));
|
|
assertThat(page.getBaseURL()).isEqualTo(loginErrorPageUrl);
|
|
assertThat(page.getBaseURL()).isEqualTo(loginErrorPageUrl);
|
|
@@ -261,8 +271,9 @@ public class OAuth2LoginApplicationTests {
|
|
void requestWhenMockOAuth2LoginThenIndex() throws Exception {
|
|
void requestWhenMockOAuth2LoginThenIndex() throws Exception {
|
|
ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId("github");
|
|
ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId("github");
|
|
this.mvc.perform(get("/").with(oauth2Login().clientRegistration(clientRegistration)))
|
|
this.mvc.perform(get("/").with(oauth2Login().clientRegistration(clientRegistration)))
|
|
- .andExpect(model().attribute("userName", "user")).andExpect(model().attribute("clientName", "GitHub"))
|
|
|
|
- .andExpect(model().attribute("userAttributes", Collections.singletonMap("sub", "user")));
|
|
|
|
|
|
+ .andExpect(model().attribute("userName", "user"))
|
|
|
|
+ .andExpect(model().attribute("clientName", "GitHub"))
|
|
|
|
+ .andExpect(model().attribute("userAttributes", Collections.singletonMap("sub", "user")));
|
|
}
|
|
}
|
|
|
|
|
|
private void assertLoginPage(HtmlPage page) {
|
|
private void assertLoginPage(HtmlPage page) {
|
|
@@ -276,10 +287,10 @@ public class OAuth2LoginApplicationTests {
|
|
ClientRegistration googleClientRegistration = this.clientRegistrationRepository.findByRegistrationId("google");
|
|
ClientRegistration googleClientRegistration = this.clientRegistrationRepository.findByRegistrationId("google");
|
|
ClientRegistration githubClientRegistration = this.clientRegistrationRepository.findByRegistrationId("github");
|
|
ClientRegistration githubClientRegistration = this.clientRegistrationRepository.findByRegistrationId("github");
|
|
ClientRegistration facebookClientRegistration = this.clientRegistrationRepository
|
|
ClientRegistration facebookClientRegistration = this.clientRegistrationRepository
|
|
- .findByRegistrationId("facebook");
|
|
|
|
|
|
+ .findByRegistrationId("facebook");
|
|
ClientRegistration oktaClientRegistration = this.clientRegistrationRepository.findByRegistrationId("okta");
|
|
ClientRegistration oktaClientRegistration = this.clientRegistrationRepository.findByRegistrationId("okta");
|
|
ClientRegistration springClientRegistration = this.clientRegistrationRepository
|
|
ClientRegistration springClientRegistration = this.clientRegistrationRepository
|
|
- .findByRegistrationId("login-client");
|
|
|
|
|
|
+ .findByRegistrationId("login-client");
|
|
|
|
|
|
String baseAuthorizeUri = AUTHORIZATION_BASE_URI + "/";
|
|
String baseAuthorizeUri = AUTHORIZATION_BASE_URI + "/";
|
|
String googleClientAuthorizeUri = baseAuthorizeUri + googleClientRegistration.getRegistrationId();
|
|
String googleClientAuthorizeUri = baseAuthorizeUri + googleClientRegistration.getRegistrationId();
|
|
@@ -304,12 +315,14 @@ public class OAuth2LoginApplicationTests {
|
|
DomNodeList<HtmlElement> divElements = page.getBody().getElementsByTagName("div");
|
|
DomNodeList<HtmlElement> divElements = page.getBody().getElementsByTagName("div");
|
|
assertThat(divElements.get(1).asNormalizedText()).contains("User: joeg@springsecurity.io");
|
|
assertThat(divElements.get(1).asNormalizedText()).contains("User: joeg@springsecurity.io");
|
|
assertThat(divElements.get(4).asNormalizedText())
|
|
assertThat(divElements.get(4).asNormalizedText())
|
|
- .contains("You are successfully logged in joeg@springsecurity.io");
|
|
|
|
|
|
+ .contains("You are successfully logged in joeg@springsecurity.io");
|
|
}
|
|
}
|
|
|
|
|
|
private HtmlAnchor getClientAnchorElement(HtmlPage page, ClientRegistration clientRegistration) {
|
|
private HtmlAnchor getClientAnchorElement(HtmlPage page, ClientRegistration clientRegistration) {
|
|
- Optional<HtmlAnchor> clientAnchorElement = page.getAnchors().stream()
|
|
|
|
- .filter((e) -> e.asNormalizedText().equals(clientRegistration.getClientName())).findFirst();
|
|
|
|
|
|
+ Optional<HtmlAnchor> clientAnchorElement = page.getAnchors()
|
|
|
|
+ .stream()
|
|
|
|
+ .filter((e) -> e.asNormalizedText().equals(clientRegistration.getClientName()))
|
|
|
|
+ .findFirst();
|
|
|
|
|
|
return (clientAnchorElement.orElse(null));
|
|
return (clientAnchorElement.orElse(null));
|
|
}
|
|
}
|
|
@@ -350,7 +363,9 @@ public class OAuth2LoginApplicationTests {
|
|
|
|
|
|
private OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> mockAccessTokenResponseClient() {
|
|
private OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> mockAccessTokenResponseClient() {
|
|
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("access-token-1234")
|
|
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("access-token-1234")
|
|
- .tokenType(OAuth2AccessToken.TokenType.BEARER).expiresIn(60 * 1000).build();
|
|
|
|
|
|
+ .tokenType(OAuth2AccessToken.TokenType.BEARER)
|
|
|
|
+ .expiresIn(60 * 1000)
|
|
|
|
+ .build();
|
|
|
|
|
|
OAuth2AccessTokenResponseClient tokenResponseClient = mock(OAuth2AccessTokenResponseClient.class);
|
|
OAuth2AccessTokenResponseClient tokenResponseClient = mock(OAuth2AccessTokenResponseClient.class);
|
|
when(tokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse);
|
|
when(tokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse);
|