瀏覽代碼

Add WebSessionOAuth2ReactiveAuthorizationRequestRepository

Issue: gh-4807
Rob Winch 7 年之前
父節點
當前提交
b613b2d253

+ 69 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/ReactiveAuthorizationRequestRepository.java

@@ -0,0 +1,69 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.client.web;
+
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.web.server.ServerWebExchange;
+
+import reactor.core.publisher.Mono;
+
+/**
+ * Implementations of this interface are responsible for the persistence
+ * of {@link OAuth2AuthorizationRequest} between requests.
+ *
+ * <p>
+ * Used by the {@link OAuth2AuthorizationRequestRedirectFilter} for persisting the Authorization Request
+ * before it initiates the authorization code grant flow.
+ * As well, used by the {@link OAuth2LoginAuthenticationFilter} for resolving
+ * the associated Authorization Request when handling the callback of the Authorization Response.
+ *
+ * @author Rob Winch
+ * @since 5.1
+ * @see OAuth2AuthorizationRequest
+ * @see HttpSessionOAuth2AuthorizationRequestRepository
+ *
+ * @param <T> The type of OAuth 2.0 Authorization Request
+ */
+public interface ReactiveAuthorizationRequestRepository<T extends OAuth2AuthorizationRequest> {
+
+	/**
+	 * Returns the {@link OAuth2AuthorizationRequest} associated to the provided {@code HttpServletRequest}
+	 * or {@code null} if not available.
+	 *
+	 * @param exchange the {@code ServerWebExchange}
+	 * @return the {@link OAuth2AuthorizationRequest} or {@code null} if not available
+	 */
+	Mono<T> loadAuthorizationRequest(ServerWebExchange exchange);
+
+	/**
+	 * Persists the {@link OAuth2AuthorizationRequest} associating it to
+	 * the provided {@code HttpServletRequest} and/or {@code HttpServletResponse}.
+	 *
+	 * @param authorizationRequest the {@link OAuth2AuthorizationRequest}
+	 * @param exchange             the {@code ServerWebExchange}
+	 */
+	Mono<Void> saveAuthorizationRequest(T authorizationRequest, ServerWebExchange exchange);
+
+	/**
+	 * Removes and returns the {@link OAuth2AuthorizationRequest} associated to the
+	 * provided {@code HttpServletRequest} or if not available returns {@code null}.
+	 *
+	 * @param exchange the {@code ServerWebExchange}
+	 * @return the removed {@link OAuth2AuthorizationRequest} or {@code null} if not available
+	 */
+	Mono<T> removeAuthorizationRequest(ServerWebExchange exchange);
+}

+ 120 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/WebSessionOAuth2ReactiveAuthorizationRequestRepository.java

@@ -0,0 +1,120 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.client.web;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import org.springframework.http.server.reactive.ServerHttpRequest;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.util.Assert;
+import org.springframework.web.server.ServerWebExchange;
+import org.springframework.web.server.WebSession;
+
+import reactor.core.publisher.Mono;
+
+/**
+ * An implementation of an {@link ReactiveAuthorizationRequestRepository} that stores
+ * {@link OAuth2AuthorizationRequest} in the {@code WebSession}.
+ *
+ * @author Rob Winch
+ * @since 5.1
+ * @see AuthorizationRequestRepository
+ * @see OAuth2AuthorizationRequest
+ */
+public final class WebSessionOAuth2ReactiveAuthorizationRequestRepository implements ReactiveAuthorizationRequestRepository<OAuth2AuthorizationRequest> {
+
+	private static final String DEFAULT_AUTHORIZATION_REQUEST_ATTR_NAME =
+			WebSessionOAuth2ReactiveAuthorizationRequestRepository.class.getName() +  ".AUTHORIZATION_REQUEST";
+
+	private final String sessionAttributeName = DEFAULT_AUTHORIZATION_REQUEST_ATTR_NAME;
+
+	@Override
+	public Mono<OAuth2AuthorizationRequest> loadAuthorizationRequest(
+			ServerWebExchange exchange) {
+		String state = getStateParameter(exchange);
+		if (state == null) {
+			return Mono.empty();
+		}
+		return getStateToAuthorizationRequest(exchange, false)
+				.filter(stateToAuthorizationRequest -> stateToAuthorizationRequest.containsKey(state))
+				.map(stateToAuthorizationRequest -> stateToAuthorizationRequest.get(state));
+	}
+
+	@Override
+	public Mono<Void> saveAuthorizationRequest(
+			OAuth2AuthorizationRequest authorizationRequest, ServerWebExchange exchange) {
+		Assert.notNull(authorizationRequest, "authorizationRequest cannot be null");
+		return getStateToAuthorizationRequest(exchange, true)
+				.doOnNext(stateToAuthorizationRequest -> stateToAuthorizationRequest.put(authorizationRequest.getState(), authorizationRequest))
+				.then();
+	}
+
+	@Override
+	public Mono<OAuth2AuthorizationRequest> removeAuthorizationRequest(
+			ServerWebExchange exchange) {
+		String state = getStateParameter(exchange);
+		if (state == null) {
+			return Mono.empty();
+		}
+		return exchange.getSession()
+			.map(WebSession::getAttributes)
+			.handle((sessionAttrs, sink) -> {
+				Map<String, OAuth2AuthorizationRequest> stateToAuthzRequest = sessionAttrsMapStateToAuthorizationRequest(sessionAttrs);
+				if (stateToAuthzRequest == null) {
+					sink.complete();
+					return;
+				}
+				OAuth2AuthorizationRequest removedValue = stateToAuthzRequest.remove(state);
+				if (stateToAuthzRequest.isEmpty()) {
+					sessionAttrs.remove(this.sessionAttributeName);
+				}
+				sink.next(removedValue);
+			});
+	}
+
+	/**
+	 * Gets the state parameter from the {@link ServerHttpRequest}
+	 * @param exchange the exchange to use
+	 * @return the state parameter or null if not found
+	 */
+	private String getStateParameter(ServerWebExchange exchange) {
+		Assert.notNull(exchange, "exchange cannot be null");
+		return exchange.getRequest().getQueryParams().getFirst(OAuth2ParameterNames.STATE);
+	}
+
+	private Mono<Map<String, Object>> getSessionAttributes(ServerWebExchange exchange) {
+		return exchange.getSession().map(WebSession::getAttributes);
+	}
+
+	private Mono<Map<String, OAuth2AuthorizationRequest>> getStateToAuthorizationRequest(ServerWebExchange exchange, boolean create) {
+		Assert.notNull(exchange, "exchange cannot be null");
+
+		return getSessionAttributes(exchange)
+			.doOnNext(sessionAttrs -> {
+				if (create) {
+					sessionAttrs.putIfAbsent(this.sessionAttributeName, new HashMap<String, OAuth2AuthorizationRequest>());
+				}
+			})
+			.flatMap(sessionAttrs -> Mono.justOrEmpty(this.sessionAttrsMapStateToAuthorizationRequest(sessionAttrs)));
+	}
+
+	private Map<String, OAuth2AuthorizationRequest> sessionAttrsMapStateToAuthorizationRequest(Map<String, Object> sessionAttrs) {
+		return (Map<String, OAuth2AuthorizationRequest>) sessionAttrs.get(this.sessionAttributeName);
+	}
+}

+ 224 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/WebSessionOAuth2ReactiveAuthorizationRequestRepositoryTests.java

@@ -0,0 +1,224 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.client.web;
+
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+import java.util.Map;
+
+import org.junit.Test;
+import org.springframework.http.codec.ServerCodecConfigurer;
+import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
+import org.springframework.mock.http.server.reactive.MockServerHttpResponse;
+import org.springframework.mock.web.server.MockServerWebExchange;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.web.server.ServerWebExchange;
+import org.springframework.web.server.WebSession;
+import org.springframework.web.server.adapter.DefaultServerWebExchange;
+import org.springframework.web.server.i18n.AcceptHeaderLocaleContextResolver;
+import org.springframework.web.server.session.WebSessionManager;
+
+import reactor.core.publisher.Mono;
+import reactor.test.StepVerifier;
+
+/**
+ * @author Rob Winch
+ * @since 5.1
+ */
+public class WebSessionOAuth2ReactiveAuthorizationRequestRepositoryTests {
+
+	private WebSessionOAuth2ReactiveAuthorizationRequestRepository repository =
+			new WebSessionOAuth2ReactiveAuthorizationRequestRepository();
+
+	private OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
+			.authorizationUri("https://example.com/oauth2/authorize")
+			.clientId("client-id")
+			.redirectUri("http://localhost/client-1")
+			.state("state")
+			.build();
+
+	private ServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/")
+			.queryParam(OAuth2ParameterNames.STATE, "state"));
+
+	@Test
+	public void loadAuthorizatioNRequestWhenNullExchangeThenIllegalArgumentException() {
+		this.exchange = null;
+		assertThatThrownBy(() -> this.repository.loadAuthorizationRequest(this.exchange))
+			.isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void loadAuthorizationRequestWhenNoSessionThenEmpty() {
+		StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange))
+				.verifyComplete();
+
+		assertSessionStartedIs(false);
+	}
+
+	@Test
+	public void loadAuthorizationRequestWhenSessionAndNoRequestThenEmpty() {
+		Mono<OAuth2AuthorizationRequest> setAttrThenLoad = this.exchange.getSession()
+				.map(WebSession::getAttributes).doOnNext(attrs -> attrs.put("foo", "bar"))
+				.then(this.repository.loadAuthorizationRequest(this.exchange));
+
+		StepVerifier.create(setAttrThenLoad)
+				.verifyComplete();
+	}
+
+	@Test
+	public void loadAuthorizationRequestWhenNoStateParamThenEmpty() {
+		this.exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/"));
+		Mono<OAuth2AuthorizationRequest> saveAndLoad = this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange)
+				.then(this.repository.loadAuthorizationRequest(this.exchange));
+
+		StepVerifier.create(saveAndLoad)
+				.verifyComplete();
+	}
+
+	@Test
+	public void loadAuthorizationRequestWhenSavedThenAuthorizationRequest() {
+		Mono<OAuth2AuthorizationRequest> saveAndLoad = this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange)
+				.then(this.repository.loadAuthorizationRequest(this.exchange));
+		StepVerifier.create(saveAndLoad)
+				.expectNext(this.authorizationRequest)
+				.verifyComplete();
+	}
+
+	@Test
+	public void loadAuthorizationRequestWhenMultipleSavedThenAuthorizationRequest() {
+		String oldState = "state0";
+		MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/")
+				.queryParam(OAuth2ParameterNames.STATE, oldState).build();
+
+		OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
+				.authorizationUri("https://example.com/oauth2/authorize")
+				.clientId("client-id")
+				.redirectUri("http://localhost/client-1")
+				.state(oldState)
+				.build();
+
+		WebSessionManager sessionManager = e -> this.exchange.getSession();
+
+		this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(), sessionManager,
+				ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
+		ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), sessionManager,
+				ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
+
+		Mono<OAuth2AuthorizationRequest> saveAndSaveAndLoad = this.repository.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange)
+				.then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange))
+				.then(this.repository.loadAuthorizationRequest(oldExchange));
+
+		StepVerifier.create(saveAndSaveAndLoad)
+				.expectNext(oldAuthorizationRequest)
+				.verifyComplete();
+
+		StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange))
+				.expectNext(this.authorizationRequest)
+				.verifyComplete();
+	}
+
+	@Test
+	public void saveAuthorizationRequestWhenAuthorizationRequestNullThenThrowsIllegalArgumentException() {
+		this.authorizationRequest = null;
+		assertThatThrownBy(() -> this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange))
+				.isInstanceOf(IllegalArgumentException.class);
+		assertSessionStartedIs(false);
+
+	}
+
+	@Test
+	public void saveAuthorizationRequestWhenExchangeNullThenThrowsIllegalArgumentException() {
+		this.exchange = null;
+		assertThatThrownBy(() -> this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange))
+				.isInstanceOf(IllegalArgumentException.class);
+
+	}
+
+	@Test
+	public void removeAuthorizationRequestWhenExchangeNullThenThrowsIllegalArgumentException() {
+		this.exchange = null;
+		assertThatThrownBy(() -> this.repository.removeAuthorizationRequest(this.exchange))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void removeAuthorizationRequestWhenNotPresentThenThrowsIllegalArgumentException() {
+		StepVerifier.create(this.repository.removeAuthorizationRequest(this.exchange))
+				.verifyComplete();
+		assertSessionStartedIs(false);
+	}
+
+	@Test
+	public void removeAuthorizationRequestWhenPresentThenFoundAndRemoved() {
+		Mono<OAuth2AuthorizationRequest> saveAndRemove = this.repository
+				.saveAuthorizationRequest(this.authorizationRequest, this.exchange)
+				.then(this.repository.removeAuthorizationRequest(this.exchange));
+
+		StepVerifier.create(saveAndRemove).expectNext(this.authorizationRequest)
+				.verifyComplete();
+
+		StepVerifier.create(this.exchange.getSession()
+				.map(WebSession::getAttributes)
+				.map(Map::isEmpty))
+				.expectNext(true)
+				.verifyComplete();
+	}
+
+	@Test
+	public void removeAuthorizationRequestWhenMultipleThenOnlyOneRemoved() {
+		String oldState = "state0";
+		MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/")
+				.queryParam(OAuth2ParameterNames.STATE, oldState).build();
+
+		OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
+				.authorizationUri("https://example.com/oauth2/authorize")
+				.clientId("client-id")
+				.redirectUri("http://localhost/client-1")
+				.state(oldState)
+				.build();
+
+		WebSessionManager sessionManager = e -> this.exchange.getSession();
+
+		this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(), sessionManager,
+				ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
+		ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), sessionManager,
+				ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
+
+		Mono<OAuth2AuthorizationRequest> saveAndSaveAndRemove = this.repository.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange)
+				.then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange))
+				.then(this.repository.removeAuthorizationRequest(this.exchange));
+
+		StepVerifier.create(saveAndSaveAndRemove)
+				.expectNext(this.authorizationRequest)
+				.verifyComplete();
+
+		StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange))
+				.verifyComplete();
+
+		StepVerifier.create(this.repository.loadAuthorizationRequest(oldExchange))
+				.expectNext(oldAuthorizationRequest)
+				.verifyComplete();
+	}
+
+	private void assertSessionStartedIs(boolean expected) {
+		Mono<Boolean> isStarted = this.exchange.getSession().map(WebSession::isStarted);
+		StepVerifier.create(isStarted)
+			.expectNext(expected)
+			.verifyComplete();
+	}
+}