Browse Source

Save original request on oauth2Client filter

When we used the oauth2Client directive and requested an endpoint that
required client authorization on the authorization server, the
SPRING_SECURITY_SAVED_REQUEST was not persisted, and therefore after
creating the authorized client we were redirected to the root page ("/").

Now we are storing the session attribute and getting redirected back to
the original URI as expected.

Note that the attribute is stored only when a
ClientAuthorizationRequiredException is thrown in the chain, we dont
want to store it as a response to the
/oauth2/authorization/{registrationId} endpoint, since we would end
up in an infinite loop

Fixes gh-6341
Gerardo Roza 6 years ago
parent
commit
95e0e7243d

+ 17 - 2
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationRequestRedirectWebFilter.java

@@ -1,5 +1,5 @@
 /*
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2019 the original author or authors.
  *
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * you may not use this file except in compliance with the License.
@@ -24,6 +24,8 @@ import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 import org.springframework.security.web.server.DefaultServerRedirectStrategy;
 import org.springframework.security.web.server.DefaultServerRedirectStrategy;
 import org.springframework.security.web.server.ServerRedirectStrategy;
 import org.springframework.security.web.server.ServerRedirectStrategy;
+import org.springframework.security.web.server.savedrequest.ServerRequestCache;
+import org.springframework.security.web.server.savedrequest.WebSessionServerRequestCache;
 import org.springframework.util.Assert;
 import org.springframework.util.Assert;
 import org.springframework.web.server.ServerWebExchange;
 import org.springframework.web.server.ServerWebExchange;
 import org.springframework.web.server.WebFilter;
 import org.springframework.web.server.WebFilter;
@@ -67,6 +69,7 @@ public class OAuth2AuthorizationRequestRedirectWebFilter implements WebFilter {
 	private final ServerOAuth2AuthorizationRequestResolver authorizationRequestResolver;
 	private final ServerOAuth2AuthorizationRequestResolver authorizationRequestResolver;
 	private ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
 	private ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
 		new WebSessionOAuth2ServerAuthorizationRequestRepository();
 		new WebSessionOAuth2ServerAuthorizationRequestRepository();
+	private ServerRequestCache requestCache = new WebSessionServerRequestCache();
 
 
 	/**
 	/**
 	 * Constructs an {@code OAuth2AuthorizationRequestRedirectFilter} using the provided parameters.
 	 * Constructs an {@code OAuth2AuthorizationRequestRedirectFilter} using the provided parameters.
@@ -98,11 +101,23 @@ public class OAuth2AuthorizationRequestRedirectWebFilter implements WebFilter {
 		this.authorizationRequestRepository = authorizationRequestRepository;
 		this.authorizationRequestRepository = authorizationRequestRepository;
 	}
 	}
 
 
+	/**
+	 * The request cache to use to save the request before sending a redirect.
+	 * @param requestCache the cache to redirect to.
+	 */
+	public void setRequestCache(ServerRequestCache requestCache) {
+		Assert.notNull(requestCache, "requestCache cannot be null");
+		this.requestCache = requestCache;
+	}
+
 	@Override
 	@Override
 	public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
 	public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
 		return this.authorizationRequestResolver.resolve(exchange)
 		return this.authorizationRequestResolver.resolve(exchange)
 			.switchIfEmpty(chain.filter(exchange).then(Mono.empty()))
 			.switchIfEmpty(chain.filter(exchange).then(Mono.empty()))
-			.onErrorResume(ClientAuthorizationRequiredException.class, e -> this.authorizationRequestResolver.resolve(exchange, e.getClientRegistrationId()))
+			.onErrorResume(ClientAuthorizationRequiredException.class, e -> {
+				return this.requestCache.saveRequest(exchange)
+					.then(this.authorizationRequestResolver.resolve(exchange, e.getClientRegistrationId()));
+			})
 			.flatMap(clientRegistration -> sendRedirectForAuthorization(exchange, clientRegistration));
 			.flatMap(clientRegistration -> sendRedirectForAuthorization(exchange, clientRegistration));
 	}
 	}
 
 

+ 34 - 1
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationRequestRedirectWebFilterTests.java

@@ -1,5 +1,5 @@
 /*
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2019 the original author or authors.
  *
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * you may not use this file except in compliance with the License.
@@ -26,6 +26,7 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
 import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
 import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.web.server.savedrequest.ServerRequestCache;
 import org.springframework.test.web.reactive.server.FluxExchangeResult;
 import org.springframework.test.web.reactive.server.FluxExchangeResult;
 import org.springframework.test.web.reactive.server.WebTestClient;
 import org.springframework.test.web.reactive.server.WebTestClient;
 import org.springframework.web.server.handler.FilteringWebHandler;
 import org.springframework.web.server.handler.FilteringWebHandler;
@@ -53,6 +54,9 @@ public class OAuth2AuthorizationRequestRedirectWebFilterTests {
 	@Mock
 	@Mock
 	private ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authzRequestRepository;
 	private ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authzRequestRepository;
 
 
+	@Mock
+	private ServerRequestCache requestCache;
+
 	private ClientRegistration registration = TestClientRegistrations.clientRegistration().build();
 	private ClientRegistration registration = TestClientRegistrations.clientRegistration().build();
 
 
 	private OAuth2AuthorizationRequestRedirectWebFilter filter;
 	private OAuth2AuthorizationRequestRedirectWebFilter filter;
@@ -139,4 +143,33 @@ public class OAuth2AuthorizationRequestRedirectWebFilterTests {
 				.is3xxRedirection()
 				.is3xxRedirection()
 				.returnResult(String.class);
 				.returnResult(String.class);
 	}
 	}
+
+	@Test
+	public void filterWhenExceptionThenSaveRequestSessionAttribute() {
+		this.filter.setRequestCache(this.requestCache);
+		when(this.requestCache.saveRequest(any())).thenReturn(Mono.empty());
+		FilteringWebHandler webHandler = new FilteringWebHandler(
+				e -> Mono.error(new ClientAuthorizationRequiredException(this.registration.getRegistrationId())),
+				Arrays.asList(this.filter));
+		this.client = WebTestClient.bindToWebHandler(webHandler).build();
+		this.client.get()
+				.uri("https://example.com/foo")
+				.exchange()
+				.expectStatus()
+				.is3xxRedirection()
+				.returnResult(String.class);
+		verify(this.requestCache).saveRequest(any());
+	}
+
+	@Test
+	public void filterWhenPathMatchesThenRequestSessionAttributeNotSaved() {
+		this.filter.setRequestCache(this.requestCache);
+		this.client.get()
+				.uri("https://example.com/oauth2/authorization/registration-id")
+				.exchange()
+				.expectStatus()
+				.is3xxRedirection()
+				.returnResult(String.class);
+		verifyZeroInteractions(this.requestCache);
+	}
 }
 }